From 65926438256c1ed46993e1c8611597af5a9c23f1 Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 7 Aug 2024 10:06:35 -0500 Subject: add: CLI improvements and small documentation updates Things done: * config path should search config file in the following order: cli option, environment, xdg_path * Implement a retrain command. * when you require more keys than the ones available in the input, stop the program. --- src/main.c | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) (limited to 'src/main.c') diff --git a/src/main.c b/src/main.c index 38f26ad..22737dc 100644 --- a/src/main.c +++ b/src/main.c @@ -51,7 +51,7 @@ void load_config(struct Configs *cfg, int n_args, ...) } else break; } va_end(ap); - die("load_config() Error:"); + die("load_config('%s') Error:", filepath); } Layer * load_network(struct Configs cfg) @@ -88,11 +88,11 @@ struct Cost load_loss(struct Configs cfg) } int main(int argc, char *argv[]) { - char default_config_path[512]; + char default_config_path[512], *env_config_path; struct Configs ml_configs = { .epochs = 100, .alpha = 1e-5, - .config_filepath = "utils/settings.cfg", + .config_filepath = "", .network_size = 0, .only_out = false, .decimal_precision = -1, @@ -103,9 +103,15 @@ int main(int argc, char *argv[]) { // First past to check if --config option was put util_load_cli(&ml_configs, argc, argv); optind = 1; + // Load configs with different possible paths sprintf(default_config_path, "%s/%s", getenv("HOME"), ".config/ml/ml.cfg"); - load_config(&ml_configs, 2, ml_configs.config_filepath, default_config_path); + env_config_path = (getenv("ML_CONFIG_PATH"))? getenv("ML_CONFIG_PATH"):""; + + load_config(&ml_configs, 3, + ml_configs.config_filepath, + env_config_path, + default_config_path); // re-read cli options again, to overwrite file configuration options util_load_cli(&ml_configs, argc, argv); @@ -115,12 +121,17 @@ int main(int argc, char *argv[]) { Layer *network = load_network(ml_configs); Array X, y; - if (!strcmp("train", argv[0])) { + if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) { file_read(argv[1], &X, &y, ml_configs.input_keys, ml_configs.n_input_keys, ml_configs.label_keys, ml_configs.n_label_keys, true, ml_configs.file_format); - nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + if (!strcmp("train", argv[0])) { + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true); + } else if (!strcmp("retrain", argv[0])) { + nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); + nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); + } nn_network_train( network, ml_configs.network_size, X.data, X.shape, @@ -139,7 +150,7 @@ int main(int argc, char *argv[]) { nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size); - // If neither output and file_format defined use input to define the format + // If neither output and file_format defined use input to define the output format if (!ml_configs.file_format && !ml_configs.out_filepath) { ml_configs.file_format = file_format_infer(ml_configs.in_filepath); } -- cgit v1.2.3-70-g09d2