From 65926438256c1ed46993e1c8611597af5a9c23f1 Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 7 Aug 2024 10:06:35 -0500 Subject: add: CLI improvements and small documentation updates Things done: * Search for the config file in the following order: CLI option (--config), ML_CONFIG_PATH environment variable, default XDG-style path ($HOME/.config/ml/ml.cfg). * Implement a retrain command. * Stop the program when more keys are required than are available in the input. --- src/main.c | 25 ++++++++++++++++++------- src/parse.c | 7 ++++++- src/util.c | 2 +- 3 files changed, 25 insertions(+), 9 deletions(-) (limited to 'src') diff --git a/src/main.c b/src/main.c index 38f26ad..22737dc 100644 --- a/src/main.c +++ b/src/main.c @@ -51,7 +51,7 @@ void load_config(struct Configs *cfg, int n_args, ...) } else break; } va_end(ap); - die("load_config() Error:"); + die("load_config('%s') Error:", filepath); } Layer * load_network(struct Configs cfg) @@ -88,11 +88,11 @@ struct Cost load_loss(struct Configs cfg) } int main(int argc, char *argv[]) { - char default_config_path[512]; + char default_config_path[512], *env_config_path; struct Configs ml_configs = { .epochs = 100, .alpha = 1e-5, - .config_filepath = "utils/settings.cfg", + .config_filepath = "", .network_size = 0, .only_out = false, .decimal_precision = -1, @@ -103,9 +103,15 @@ int main(int argc, char *argv[]) { // First past to check if --config option was put util_load_cli(&ml_configs, argc, argv); optind = 1; + // Load configs with different possible paths sprintf(default_config_path, "%s/%s", getenv("HOME"), ".config/ml/ml.cfg"); - load_config(&ml_configs, 2, ml_configs.config_filepath, default_config_path); + env_config_path = (getenv("ML_CONFIG_PATH"))? 
getenv("ML_CONFIG_PATH"):""; + + load_config(&ml_configs, 3, + ml_configs.config_filepath, + env_config_path, + default_config_path); // re-read cli options again, to overwrite file configuration options util_load_cli(&ml_configs, argc, argv); @@ -115,12 +121,17 @@ int main(int argc, char *argv[]) { Layer *network = load_network(ml_configs); Array X, y; - if (!strcmp("train", argv[0])) { + if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) { file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, true, ml_configs.file_format); - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + if (!strcmp("train", argv[0])) { + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + } else if (!strcmp("retrain", argv[0])) { + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); + } nn_network_train( network, ml_configs.network_size, X.data, X.shape, @@ -139,7 +150,7 @@ int main(int argc, char *argv[]) { nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); - // If neither output and file_format defined use input to define the format + // If neither output and file_format defined use input to define the output format if (!ml_configs.file_format && !ml_configs.out_filepath) { ml_configs.file_format = file_format_infer(ml_configs.in_filepath); } diff --git a/src/parse.c b/src/parse.c index cea595b..a06f0f3 100644 --- a/src/parse.c +++ b/src/parse.c @@ -187,6 +187,11 @@ void json_read( die("json_read() Error: unexpected JSON data received, expecting an object"); } + if ((size_t)json_object_object_length(item) < n_input_keys + n_out_keys) { + die("json_read() Error: the number of keys required is greater " + "than the keys 
available in the object:\n%s", + json_object_to_json_string_ext(item, JSON_C_TO_STRING_PRETTY)); + } for (j = 0; j < n_input_keys; j++) { value = json_object_object_get(item, in_keys[j]); obj_type = json_object_get_type(value); @@ -517,7 +522,7 @@ int main(int argc, char *argv[]) { // use input format if format variable is not defined format = (!format && !strcmp(out_file, "-")) ? file_format_infer(in_file) : format; - file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format); + file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format, -1); for (i = 0; i < n_in_cols; i++) free(in_cols[i]); for (i = 0; i < n_out_cols; i++) free(out_cols[i]); diff --git a/src/util.c b/src/util.c index 4621836..9a00aa3 100644 --- a/src/util.c +++ b/src/util.c @@ -91,7 +91,7 @@ void usage(int exit_code) { FILE *fp = (!exit_code) ? stdout : stderr; fprintf(fp, - "Usage: ml train [Options] FILE\n" + "Usage: ml [re]train [Options] FILE\n" " or: ml predict [-Ohv] [-f FORMAT] [-o FILE] [-p INT] FILE\n" "\n" "Options:\n" -- cgit v1.2.3-70-g09d2