diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.c | 33 | ||||
-rw-r--r-- | src/parse.c | 353 | ||||
-rw-r--r-- | src/parse.h | 27 | ||||
-rw-r--r-- | src/util.c | 140 | ||||
-rw-r--r-- | src/util.h | 7 |
5 files changed, 472 insertions, 88 deletions
@@ -79,6 +79,7 @@ Layer * load_network(struct Configs cfg) return network; } + int main(int argc, char *argv[]) { char default_config_path[512], *env_config_path; struct Configs ml_configs = { @@ -113,36 +114,44 @@ int main(int argc, char *argv[]) { argv += optind; Layer *network = load_network(ml_configs); - - Array X, y; + Array in, out; + double *X = NULL, *y = NULL; + size_t X_shape[2], y_shape[2]; if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) { - file_read(argv[1], &X, &y, ml_configs, true); + file_read(argv[1], &in, &out, ml_configs, true); + X = data_preprocess(X_shape, in, ml_configs, true, false); + y = data_preprocess(y_shape, out, ml_configs, false, false); if (!strcmp("train", argv[0])) { - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + nn_network_init_weights(network, ml_configs.network_size, X_shape[1], true); } else if (!strcmp("retrain", argv[0])) { - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + nn_network_init_weights(network, ml_configs.network_size, X_shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); } - nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape); + nn_network_train(network, ml_configs, X, X_shape, y, y_shape); nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { - file_read(argv[1], &X, &y, ml_configs, false); - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + file_read(argv[1], &in, &out, ml_configs, false); + X = data_preprocess(X_shape, in, ml_configs, true, false); + y = data_preprocess(y_shape, out, ml_configs, false, true); + nn_network_init_weights(network, ml_configs.network_size, X_shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); - nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); + nn_network_predict(y, y_shape, X, X_shape, network, ml_configs.network_size); // If neither output and file_format defined use input to define the output format if (!ml_configs.file_format && !ml_configs.out_filepath) { ml_configs.file_format = file_format_infer(ml_configs.in_filepath); } - file_write(X, y, ml_configs); + data_postprocess(&out, y, y_shape, ml_configs, false); + file_write(in, out, ml_configs); } else usage(1); nn_network_free_weights(network, ml_configs.network_size); free(network); - free(X.data); - free(y.data); + array_free(&in); + array_free(&out); + free(X); + free(y); util_free_config(&ml_configs); return 0; } diff --git a/src/parse.c b/src/parse.c index 0533a13..c3fe2a6 100644 --- a/src/parse.c +++ b/src/parse.c @@ -29,8 +29,7 @@ static void json_read( FILE *fp, Array *input, Array *out, - char *in_keys[], size_t in_keys_size, - char *out_keys[], size_t out_keys_size, + struct Configs cfgs, bool read_output ); @@ -47,11 +46,7 @@ static void csv_read( static void json_write( FILE *fp, Array input, Array out, - char *in_keys[], size_t in_keys_size, - char *out_keys[], size_t out_keys_size, - bool write_input, - int decimal_precision - ); + struct Configs cfgs); static void csv_write( FILE *fp, @@ -73,7 +68,6 @@ static void csv_readline_values( static void csv_keys2cols(size_t cols[], char *keys[], size_t keys_size); - void file_read( char *filepath, Array *input, Array *out, @@ -82,10 +76,6 @@ void file_read( { FILE *fp; - char **in_keys = ml_config.input_keys; - char **out_keys = ml_config.label_keys; - size_t n_in_keys = ml_config.n_input_keys; - size_t n_out_keys = ml_config.n_label_keys; char *file_format = ml_config.file_format; if (filepath != NULL && strcmp(filepath, "-")) { @@ -102,9 +92,11 @@ void file_read( file_format = file_format_infer(filepath); } + /* if (!strcmp(file_format, "csv")) csv_read(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, read_output, false, ','); else if (!strcmp(file_format, "tsv")) csv_read(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, read_output, false, '\t'); - else if (!strcmp(file_format, "json")) json_read(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, read_output); + */ + if (!strcmp(file_format, "json")) json_read(fp, input, out, ml_config, read_output); else { die("file_read() Error: unable to parse %s files", file_format); } @@ -117,14 +109,7 @@ void file_write(Array input, Array out, struct Configs ml_config) FILE *fp; char *filepath = ml_config.out_filepath; - char **in_keys = ml_config.input_keys; - char **out_keys = ml_config.label_keys; - size_t n_in_keys = ml_config.n_input_keys; - size_t n_out_keys = ml_config.n_label_keys; - bool write_input = !ml_config.only_out; char *file_format = ml_config.file_format; - char decimal_precision = ml_config.decimal_precision; - if (filepath != NULL && strcmp(filepath, "-")) { fp = fopen(filepath, "w"); @@ -138,20 +123,165 @@ void file_write(Array input, Array out, struct Configs ml_config) if (fp == NULL) die("file_write() Error:"); - if (!strcmp(file_format, "json")) json_write(fp, input, out, in_keys, n_in_keys, out_keys, n_out_keys, write_input, decimal_precision); + if (!strcmp(file_format, "json")) json_write(fp, input, out, ml_config); + /* else if (!strcmp(file_format, "csv")) csv_write(fp, input, out, write_input, ',', decimal_precision); else if (!strcmp(file_format, "tsv")) csv_write(fp, input, out, write_input, '\t', decimal_precision); + */ else { die("file_write() Error: unable to write %s files", file_format); } fclose(fp); } +void data_postprocess( + Array *out, + double *data, size_t data_shape[2], + struct Configs cfgs, + bool is_input) +{ + char **keys = (is_input) ? cfgs.input_keys : cfgs.label_keys; + size_t n_keys = (is_input) ? cfgs.n_input_keys : cfgs.n_label_keys; + + char **categorical_keys = cfgs.categorical_keys; + size_t n_categorical_keys = cfgs.n_categorical_keys; + + char ***categorical_values = cfgs.categorical_values; + size_t *n_categorical_values = cfgs.n_categorical_values; + + size_t i, j, data_j; + for (data_j = j = 0; j < n_keys; j++) { + int k; + switch (out->type[j]) { + case ARRAY_NUMERICAL: + for (i = 0; i < data_shape[0]; i++) { + size_t data_index = i * data_shape[1] + data_j; + size_t index = i * out->shape[1] + j; + out->data[index].numeric = data[data_index]; + } + data_j++; + break; + case ARRAY_ONEHOT: + k = util_get_key_index(keys[j], categorical_keys, n_categorical_keys); + if (k == -1) { + die("data_postprocess() Error: field '%s' is not registered as categorical", + keys[j]); + } + for (i = 0; i < data_shape[0]; i++) { + size_t index = i * out->shape[1] + j; + size_t data_index = i * data_shape[1] + data_j; + int onehot_i = util_argmax(data + data_index, n_categorical_values[k]); + out->data[index].categorical = e_strdup(categorical_values[k][onehot_i]); + } + data_j += n_categorical_values[k]; + break; + default: + die("data_postprocess() Error: unexpected type received on '%s' field", keys[j]); + } + } +} + +double * data_preprocess( + size_t out_shape[2], + Array data, + struct Configs cfgs, + bool is_input, + bool only_allocate) +{ + double *out; + + char **keys = (is_input) ? cfgs.input_keys : cfgs.label_keys; + size_t n_keys = (is_input) ? cfgs.n_input_keys : cfgs.n_label_keys; + + char **categorical_keys = cfgs.categorical_keys; + size_t n_categorical_keys = cfgs.n_categorical_keys; + + char ***categorical_values = cfgs.categorical_values; + size_t *n_categorical_values = cfgs.n_categorical_values; + + size_t i, j, out_j; + + out_shape[0] = data.shape[0]; + out_shape[1] = 0; + for (i = 0; i < n_keys; i++) { + int n; + switch (data.type[i]) { + case ARRAY_NUMERICAL: + out_shape[1]++; + break; + case ARRAY_ONEHOT: + n = util_get_key_index(keys[i], categorical_keys, n_categorical_keys); + if (n == -1) die("data_preprocess() Error: field '%s' is not marked as categorical", keys[i]); + out_shape[1] += n_categorical_values[n]; + break; + default: + die("data_preprocess() Error: field '%s' has an unknown type", keys[i]); + break; + } + } + + out = ecalloc(out_shape[0] * out_shape[1], sizeof(double)); + if (only_allocate) return out; + + for (out_j = j = 0; j < data.shape[1]; j++) { + switch (data.type[j]) { + int k; + case ARRAY_NUMERICAL: + for (i = 0; i < out_shape[0]; i++) { + size_t index = i * data.shape[1] + j; + size_t out_index = i * out_shape[1] + out_j; + out[out_index] = data.data[index].numeric; + } + out_j++; + break; + case ARRAY_ONEHOT: + k = util_get_key_index(keys[j], categorical_keys, n_categorical_keys); + for (i = 0; i < out_shape[0]; i++) { + int onehot_i; + size_t index = i * data.shape[1] + j; + onehot_i = util_get_key_index(data.data[index].categorical, + categorical_values[k], + n_categorical_values[k]); + if (onehot_i == -1) { + die("data_preprocess() Error: unexpected '%s' value found", + data.data[index].categorical); + } + size_t out_index = i * out_shape[1] + out_j + onehot_i; + out[out_index] = 1.0; + } + out_j += n_categorical_values[k]; + break; + default: + die("data_preprocess() Error: field '%s' has an unknown type", keys[j]); + } + } + + return out; +} + +void array_free(Array *x) { + size_t i, j, index; + for (j = 1; j < x->shape[1]; j++) { + switch (x->type[j]) { + case ARRAY_ORDINAL: + case ARRAY_ONEHOT: + for (i = 0; i < x->shape[0]; i++) { + index = x->shape[1] * i + j; + free(x->data[index].categorical); + } + break; + default: + break; + } + } + free(x->type); + free(x->data); +} + void json_read( FILE *fp, Array *input, Array *out, - char *in_keys[], size_t n_input_keys, - char *out_keys[], size_t n_out_keys, + struct Configs cfgs, bool read_output) { static char fp_buffer[MAX_FILE_SIZE]; @@ -159,8 +289,15 @@ void json_read( json_object *json_obj, *item, *value; json_type obj_type; + char **in_keys = cfgs.input_keys; + char **out_keys = cfgs.label_keys; + char **onehot_keys = cfgs.onehot_keys; + size_t n_input_keys = cfgs.n_input_keys; + size_t n_out_keys = cfgs.n_label_keys; + size_t n_onehot_keys = cfgs.n_onehot_keys; - if (fp == NULL) goto json_read_error; + + if (fp == NULL) die("json_read() Error:"); i = 0; do { @@ -178,13 +315,28 @@ void json_read( input->shape[0] = (size_t)json_obj_length; input->shape[1] = n_input_keys; - input->data = calloc(input->shape[0] * input->shape[1], sizeof(input->data[0])); + input->type = ecalloc(input->shape[1], sizeof(enum ArrayType)); + input->data = ecalloc(input->shape[0] * input->shape[1], sizeof(input->data[0])); out->shape[0] = (size_t)json_obj_length; out->shape[1] = n_out_keys; - out->data = calloc(out->shape[0] * out->shape[1], sizeof(out->data[0])); + out->type = ecalloc(out->shape[1], sizeof(enum ArrayType)); + out->data = ecalloc(out->shape[0] * out->shape[1], sizeof(out->data[0])); + + for (i = 0; i < n_onehot_keys; i++) { + for (j = 0; j < n_input_keys; j++) { + if (!strcmp(onehot_keys[i], in_keys[j])) { + input->type[j] = ARRAY_ONEHOT; + } + } + + for (j = 0; j < n_out_keys; j++) { + if (!strcmp(onehot_keys[i], out_keys[j])) { + out->type[j] = ARRAY_ONEHOT; + } + } + } - if (!input->data || !out->data) goto json_read_error; for (i = 0; i < json_object_array_length(json_obj); i++) { item = json_object_array_get_idx(json_obj, i); @@ -201,15 +353,30 @@ void json_read( for (j = 0; j < n_input_keys; j++) { value = json_object_object_get(item, in_keys[j]); obj_type = json_object_get_type(value); - switch (obj_type) { - case json_type_double: - case json_type_int: - index = n_input_keys * i + j; - input->data[index] = json_object_get_double(value); + index = n_input_keys * i + j; + switch (input->type[j]) { + case ARRAY_NUMERICAL: + switch (obj_type) { + case json_type_int: + case json_type_double: + input->data[index].numeric = json_object_get_double(value); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting a number"); + } break; - default: - die("json_read() Error: unexpected JSON data received, expecting a number"); + case ARRAY_ONEHOT: + switch (obj_type) { + case json_type_int: + case json_type_string: + input->data[index].categorical = e_strdup(json_object_get_string(value)); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting a string or integer"); + } break; + default: + die("json_read() Error: preprocess field type '%s' is not implemented", in_keys[j]); } } @@ -218,28 +385,40 @@ void json_read( for (j = 0; j < n_out_keys; j++) { value = json_object_object_get(item, out_keys[j]); obj_type = json_object_get_type(value); - switch (obj_type) { - case json_type_double: - case json_type_int: - index = n_out_keys * i + j; - out->data[index] = json_object_get_double(value); + index = n_out_keys * i + j; + switch (out->type[j]) { + case ARRAY_NUMERICAL: + switch (obj_type) { + case json_type_int: + case json_type_double: + out->data[index].numeric = json_object_get_double(value); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting a number"); + } break; - default: - die("json_read() Error: unexpected JSON data received, expecting a number"); + case ARRAY_ONEHOT: + switch (obj_type) { + case json_type_int: + case json_type_string: + out->data[index].categorical = e_strdup(json_object_get_string(value)); + break; + default: + die("json_read() Error: unexpected JSON data received, expecting string or integer"); + } break; + default: + die("json_read() Error: preprocess field type '%s' is not implemented", out_keys[j]); } } } json_object_put(json_obj); return; - -json_read_error: - perror("json_read() Error"); - exit(1); } +/* void csv_read( FILE *fp, Array *input, Array *out, @@ -304,51 +483,78 @@ void csv_read( free(out_cols); return; } +*/ void json_write( FILE *fp, Array input, Array out, - char *in_keys[], size_t in_keys_size, - char *out_keys[], size_t out_keys_size, - bool write_input, - int decimal_precision) + struct Configs cfgs) { - fprintf(fp, "[\n"); - - if (in_keys_size != input.shape[1] && write_input) { - die("json_write() Error: there are more keys (%zu) than input columns (%zu)", - in_keys_size, input.shape[1]); + char **in_keys = cfgs.input_keys; + char **out_keys = cfgs.label_keys; + size_t n_in_keys = cfgs.n_input_keys; + size_t n_out_keys = cfgs.n_label_keys; + bool write_input = !cfgs.only_out; + int decimal_precision = cfgs.decimal_precision; + + json_object *root = json_object_new_array(); + if (!root) { + die("json_write() Error: Unable to create json_data"); } - if (out_keys_size != out.shape[1]) { - die("json_write() Error: there are more keys (%zu) than output columns (%zu)", - out_keys_size, out.shape[1]); - } + if (n_in_keys != input.shape[1]) + die("json_write() Error: input keys and data columns have different sizes"); + if (n_out_keys != out.shape[1]) + die("json_write() Error: output keys and data columns have different sizes"); - for (size_t i = 0; i < input.shape[0]; i++) { - fprintf(fp, " {\n"); + size_t i, j; + for (i = 0; i < input.shape[0]; i++) { + json_object *obj = json_object_new_object(); if (write_input) { - for (size_t j = 0; j < input.shape[1]; j++) { - size_t index = input.shape[1] * i + j; - fprintf(fp, " \"%s\": %g,\n", in_keys[j], input.data[index]); + for (j = 0; j < input.shape[1]; j++) { + char buffer[128]; + size_t index = i * input.shape[1] + j; + switch (input.type[j]) { + case ARRAY_NUMERICAL: + sprintf(buffer, "%g", input.data[index].numeric); + json_object_object_add(obj, in_keys[j], json_object_new_double_s(input.data[index].numeric, buffer)); + break; + case ARRAY_ONEHOT: + json_object_object_add(obj, in_keys[j], json_object_new_string(input.data[index].categorical)); + break; + default: + die("json_write(): Unexpected value received"); + } } } - for (size_t j = 0; j < out.shape[1]; j++) { - size_t index = out.shape[1] * i + j; - fprintf(fp, " \"%s\": %.*g", out_keys[j], decimal_precision, out.data[index]); - if (j == out.shape[1] - 1) fprintf(fp, "\n"); - else fprintf(fp, ",\n"); + for (j = 0; j < out.shape[1]; j++) { + size_t index = i * out.shape[1] + j; + char buffer[32]; + switch (out.type[j]) { + case ARRAY_NUMERICAL: + sprintf(buffer, "%.*g", decimal_precision, out.data[index].numeric); + json_object_object_add(obj, out_keys[j], json_object_new_double_s(out.data[index].numeric, buffer)); + break; + case ARRAY_ONEHOT: + json_object_object_add(obj, out_keys[j], json_object_new_string(out.data[index].categorical)); + break; + default: + die("json_write(): Unexpected value received"); + } } - - if (i == input.shape[0] - 1) fprintf(fp, " }\n"); - else fprintf(fp, " },\n"); + json_object_array_add(root, obj); + } + int ret = fprintf(fp, "%s", json_object_to_json_string_ext(root, JSON_C_TO_STRING_PRETTY | JSON_C_TO_STRING_SPACED)); + if (ret == -1) { + die("json_write() Error: unable to write json data"); } - fprintf(fp, "]\n"); + json_object_put(root); } +/* void csv_write( FILE *fp, Array input, Array out, @@ -373,6 +579,7 @@ void csv_write( fprintf(fp, "\n"); } } +*/ void csv_columns_select( double *dst_row, double *src_row, @@ -457,7 +664,6 @@ char * file_format_infer(char *filename) return file_format; } - #ifdef PARSE_TEST #include <assert.h> #include <string.h> @@ -546,4 +752,5 @@ int main(int argc, char *argv[]) { return 0; } + #endif diff --git a/src/parse.h b/src/parse.h index 07f740b..3f49c15 100644 --- a/src/parse.h +++ b/src/parse.h @@ -6,13 +6,38 @@ #include "util.h" +enum ArrayType { + ARRAY_NUMERICAL, + ARRAY_ORDINAL, + ARRAY_ONEHOT +}; + +union ArrayValue { + double numeric; + char *categorical; +}; + typedef struct Array { - double *data; + enum ArrayType *type; + union ArrayValue *data; size_t shape[2]; } Array; +void array_free(Array *x); void file_read(char *filepath, Array *input, Array *out, struct Configs configs, bool read_output); void file_write(Array input, Array out, struct Configs ml_configs); char * file_format_infer(char *filename); +double * data_preprocess( + size_t out_shape[2], + Array data, + struct Configs configs, + bool is_input, + bool only_allocate); + +void data_postprocess( + Array *out, + double *data, size_t data_shape[2], + struct Configs cfgs, + bool is_input); #endif @@ -25,9 +25,13 @@ #include "util.h" #define BUFFER_SIZE 1024 + +static int cmpstringp(const void *, const void *); static char ** config_read_values(size_t *n_out_keys, char *first_value, char **strtok_ptr); static void load_net_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr, char *filepath); static void load_lyr_cfgs(struct Configs *cfg, char *key, char *value, char *filepath); +static void load_categorical_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr); +static void load_preprocess_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr, char *filepath); static void add_lyr(struct Configs *cfg); void die(const char *fmt, ...) @@ -207,11 +211,34 @@ void util_free_config(struct Configs *ml) free(ml->activations[i]); free(ml->activations); } + + if (ml->onehot_keys != NULL) { + for (size_t i = 0; i < ml->n_onehot_keys; i++) + free(ml->onehot_keys[i]); + free(ml->onehot_keys); + } + + if (ml->categorical_keys != NULL) { + for (size_t i = 0; i < ml->n_categorical_keys; i++) + free(ml->categorical_keys[i]); + free(ml->categorical_keys); + } + + if (ml->categorical_values != NULL) { + for (size_t i = 0; i < ml->n_categorical_keys; i++) { + for (size_t j = 0; j < ml->n_categorical_values[i]; j++) { + free(ml->categorical_values[i][j]); + } + free(ml->categorical_values[i]); + } + free(ml->n_categorical_values); + free(ml->categorical_values); + } } void util_load_config(struct Configs *ml, char *filepath) { - enum Section {NET, LAYER, OUT_LAYER}; + enum Section {NET, PREPROCESSING, CATEGORICAL, LAYER, OUT_LAYER}; enum Section section; int line_number = 0; char line_buffer[BUFFER_SIZE], line_buffer_original[BUFFER_SIZE]; @@ -234,6 +261,10 @@ void util_load_config(struct Configs *ml, char *filepath) ml->network_size++; add_lyr(ml); ml->neurons[ml->network_size-1] = ml->n_label_keys; + } else if (!strcmp("preprocessing", token_buffer)) { + section = PREPROCESSING; + } else if (!strcmp("categorical_fields", token_buffer)) { + section = CATEGORICAL; } else { die("util_load_config() Error: Unknown section '%s' on %s", line_buffer, filepath); @@ -277,6 +308,12 @@ void util_load_config(struct Configs *ml, char *filepath) case NET: load_net_cfgs(ml, key, value, ptr_buffer, filepath); break; + case PREPROCESSING: + load_preprocess_cfgs(ml, key, value, ptr_buffer, filepath); + break; + case CATEGORICAL: + load_categorical_cfgs(ml, key, value, ptr_buffer); + break; case LAYER: load_lyr_cfgs(ml, key, value, filepath); break; @@ -292,6 +329,48 @@ void util_load_config(struct Configs *ml, char *filepath) break; } } + + /* Checks categorical_keys in label_keys or input_keys or onehot_keys*/ + size_t i,j,k; + for (i = 0; i < ml->n_categorical_keys; i++) { + int ret; + ret = util_get_key_index(ml->categorical_keys[i], ml->input_keys, ml->n_input_keys); + if (ret >= 0) continue; + ret = util_get_key_index(ml->categorical_keys[i], ml->label_keys, ml->n_label_keys); + if (ret == -1) { + die("util_load_config() Error: field '%s' does not exist", ml->categorical_keys[i]); + } + + ret = util_get_key_index(ml->categorical_keys[i], ml->onehot_keys, ml->n_onehot_keys); + if (ret >= 0) continue; + die("util_load_config() Error: field '%s' must be encoded", ml->categorical_keys[i]); + } + + /* Check onehot_keys in categorical_keys */ + for (i = 0; i < ml->n_onehot_keys; i++) { + int ret = util_get_key_index(ml->onehot_keys[i], + ml->categorical_keys, + ml->n_categorical_keys); + if (ret >= 0) continue; + die("util_load_config() Error: one hot field '%s' is not defined as categorical", ml->onehot_keys[i]); + } + + /* Determine out layer neurons */ + size_t *out_layer_neurons = ml->neurons + ml->network_size - 1; + *out_layer_neurons = 0; + for (i = 0; i < ml->n_label_keys; i++) { + int ret = 1; + + for (j = 0; ret && j < ml->n_categorical_keys; j++) + ret = strcmp(ml->categorical_keys[j], ml->label_keys[i]); + + for (k = 0; ret && k < ml->n_onehot_keys; k++) + ret = strcmp(ml->onehot_keys[k], ml->label_keys[i]); + + *out_layer_neurons += (!ret) ? ml->n_categorical_values[i] : 1; + } + + fclose(fp); return; @@ -331,7 +410,40 @@ void load_net_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr else if (!strcmp(key, "alpha")) cfg->alpha = (double)atof(value); else if (!strcmp(key, "inputs")) cfg->input_keys = config_read_values(&(cfg->n_input_keys), value, &strtok_ptr); else if (!strcmp(key, "labels")) cfg->label_keys = config_read_values(&(cfg->n_label_keys), value, &strtok_ptr); - else die("util_load_config() Error: Unknown parameter '%s' on file %s.", key, filepath); + else die("util_load_config() Error: Invalid parameter '%s' in [net] section on file %s.", key, filepath); +} + +void load_preprocess_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr, char *filepath) +{ + if (!strcmp(key, "onehot")) cfg->onehot_keys = config_read_values(&cfg->n_onehot_keys, value, &strtok_ptr); + else die("util_load_config() Error: Invalid parameter '%s' in [preprocess] section on file %s", key, filepath); +} + + +void load_categorical_cfgs( + struct Configs *cfg, + char *key, char *value, + char *strtok_ptr) +{ + size_t size, *value_size; + + size = cfg->n_categorical_keys; + if (cfg->n_categorical_keys == 0) { + cfg->categorical_keys = ecalloc(1, sizeof(char *)); + cfg->categorical_values = ecalloc(1, sizeof(char **)); + cfg->n_categorical_values = ecalloc(1, sizeof(size_t)); + cfg->n_categorical_keys++; + } else { + cfg->categorical_keys = erealloc(cfg->categorical_keys, sizeof(char *) * (size + 1)); + cfg->categorical_values = erealloc(cfg->categorical_values, sizeof(char *) * (size + 1)); + cfg->n_categorical_values = erealloc(cfg->n_categorical_values, sizeof(size_t) * (size + 1)); + cfg->n_categorical_keys++; + } + + value_size = cfg->n_categorical_values + size; + cfg->categorical_keys[size] = e_strdup(key); + cfg->categorical_values[size] = config_read_values(value_size, value, &strtok_ptr); + qsort(cfg->categorical_values[size], *value_size, sizeof(char *), cmpstringp); } char ** config_read_values(size_t *n_out_keys, char *first_value, char **strtok_ptr) @@ -348,4 +460,28 @@ char ** config_read_values(size_t *n_out_keys, char *first_value, char **strtok_ } return out_keys; } + +int util_get_key_index(char *key, char **keys, size_t n_keys) +{ + int i; + for (i = 0; (size_t)i < n_keys; i++) + if (!strcmp(key, keys[i])) return i; + return -1; +} + +int util_argmax(double *values, size_t n_values) +{ + double value = values[0]; + size_t i, j; + for (i = j = 0; i < n_values; i++) { + if (values[i] > value) j = i; + value = values[i]; + } + return j; +} + +int cmpstringp(const void *p1, const void *p2) +{ + return strcmp(*(const char **) p1, *(const char **) p2); +} #undef BUFFER_SIZE @@ -12,9 +12,14 @@ struct Configs { char *loss; char **input_keys, **label_keys; size_t n_input_keys, n_label_keys; + char **categorical_keys, ***categorical_values; + size_t n_categorical_keys, *n_categorical_values; char *weights_filepath; char *config_filepath; bool shuffle; + /* preprocessing */ + char **onehot_keys; + size_t n_onehot_keys; /* cli cfgs */ char *file_format; char *in_filepath; @@ -32,6 +37,8 @@ void die(const char *fmt, ...); void *ecalloc(size_t nmemb, size_t size); void *erealloc(void *ptr, size_t size); char *e_strdup(const char *s); +int util_get_key_index(char *key, char **keys, size_t n_keys); +int util_argmax(double *values, size_t n_values); void util_load_cli(struct Configs *ml, int argc, char *argv[]); void util_load_config(struct Configs *ml, char *filepath); void util_free_config(struct Configs *ml); |