diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/main.c | 25 | ||||
-rw-r--r-- | src/parse.c | 7 | ||||
-rw-r--r-- | src/util.c | 2 |
3 files changed, 25 insertions, 9 deletions
@@ -51,7 +51,7 @@ void load_config(struct Configs *cfg, int n_args, ...) } else break; } va_end(ap); - die("load_config() Error:"); + die("load_config('%s') Error:", filepath); } Layer * load_network(struct Configs cfg) @@ -88,11 +88,11 @@ struct Cost load_loss(struct Configs cfg) } int main(int argc, char *argv[]) { - char default_config_path[512]; + char default_config_path[512], *env_config_path; struct Configs ml_configs = { .epochs = 100, .alpha = 1e-5, - .config_filepath = "utils/settings.cfg", + .config_filepath = "", .network_size = 0, .only_out = false, .decimal_precision = -1, @@ -103,9 +103,15 @@ int main(int argc, char *argv[]) { // First past to check if --config option was put util_load_cli(&ml_configs, argc, argv); optind = 1; + // Load configs with different possible paths sprintf(default_config_path, "%s/%s", getenv("HOME"), ".config/ml/ml.cfg"); - load_config(&ml_configs, 2, ml_configs.config_filepath, default_config_path); + env_config_path = (getenv("ML_CONFIG_PATH"))? getenv("ML_CONFIG_PATH"):""; + + load_config(&ml_configs, 3, + ml_configs.config_filepath, + env_config_path, + default_config_path); // re-read cli options again, to overwrite file configuration options util_load_cli(&ml_configs, argc, argv); @@ -115,12 +121,17 @@ int main(int argc, char *argv[]) { Layer *network = load_network(ml_configs); Array X, y; - if (!strcmp("train", argv[0])) { + if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) { file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, true, ml_configs.file_format); - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + if (!strcmp("train", argv[0])) { + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + } else if (!strcmp("retrain", argv[0])) { + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); + } nn_network_train( network, ml_configs.network_size, X.data, X.shape, @@ -139,7 +150,7 @@ int main(int argc, char *argv[]) { nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); - // If neither output and file_format defined use input to define the format + // If neither output and file_format defined use input to define the output format if (!ml_configs.file_format && !ml_configs.out_filepath) { ml_configs.file_format = file_format_infer(ml_configs.in_filepath); } diff --git a/src/parse.c b/src/parse.c index cea595b..a06f0f3 100644 --- a/src/parse.c +++ b/src/parse.c @@ -187,6 +187,11 @@ void json_read( die("json_read() Error: unexpected JSON data received, expecting an object"); } + if ((size_t)json_object_object_length(item) < n_input_keys + n_out_keys) { + die("json_read() Error: the number of keys required is greater " + "than the keys available in the object:\n%s", + json_object_to_json_string_ext(item, JSON_C_TO_STRING_PRETTY)); + } for (j = 0; j < n_input_keys; j++) { value = json_object_object_get(item, in_keys[j]); obj_type = json_object_get_type(value); @@ -517,7 +522,7 @@ int main(int argc, char *argv[]) { // use input format if format variable is not defined format = (!format && !strcmp(out_file, "-")) ? file_format_infer(in_file) : format; - file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format); + file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format, -1); for (i = 0; i < n_in_cols; i++) free(in_cols[i]); for (i = 0; i < n_out_cols; i++) free(out_cols[i]); @@ -91,7 +91,7 @@ void usage(int exit_code) { FILE *fp = (!exit_code) ? stdout : stderr; fprintf(fp, - "Usage: ml train [Options] FILE\n" + "Usage: ml [re]train [Options] FILE\n" " or: ml predict [-Ohv] [-f FORMAT] [-o FILE] [-p INT] FILE\n" "\n" "Options:\n" |