aboutsummaryrefslogtreecommitdiff
path: root/src/main.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/main.c')
-rw-r--r--src/main.c25
1 files changed, 18 insertions, 7 deletions
diff --git a/src/main.c b/src/main.c
index 38f26ad..22737dc 100644
--- a/src/main.c
+++ b/src/main.c
@@ -51,7 +51,7 @@ void load_config(struct Configs *cfg, int n_args, ...)
} else break;
}
va_end(ap);
- die("load_config() Error:");
+ die("load_config('%s') Error:", filepath);
}
Layer * load_network(struct Configs cfg)
@@ -88,11 +88,11 @@ struct Cost load_loss(struct Configs cfg)
}
int main(int argc, char *argv[]) {
- char default_config_path[512];
+ char default_config_path[512], *env_config_path;
struct Configs ml_configs = {
.epochs = 100,
.alpha = 1e-5,
- .config_filepath = "utils/settings.cfg",
+ .config_filepath = "",
.network_size = 0,
.only_out = false,
.decimal_precision = -1,
@@ -103,9 +103,15 @@ int main(int argc, char *argv[]) {
// First past to check if --config option was put
util_load_cli(&ml_configs, argc, argv);
optind = 1;
+
// Load configs with different possible paths
sprintf(default_config_path, "%s/%s", getenv("HOME"), ".config/ml/ml.cfg");
- load_config(&ml_configs, 2, ml_configs.config_filepath, default_config_path);
+ env_config_path = (getenv("ML_CONFIG_PATH"))? getenv("ML_CONFIG_PATH"):"";
+
+ load_config(&ml_configs, 3,
+ ml_configs.config_filepath,
+ env_config_path,
+ default_config_path);
// re-read cli options again, to overwrite file configuration options
util_load_cli(&ml_configs, argc, argv);
@@ -115,12 +121,17 @@ int main(int argc, char *argv[]) {
Layer *network = load_network(ml_configs);
Array X, y;
- if (!strcmp("train", argv[0])) {
+ if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) {
file_read(argv[1], &X, &y,
ml_configs.input_keys, ml_configs.n_input_keys,
ml_configs.label_keys, ml_configs.n_label_keys,
true, ml_configs.file_format);
- nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
+ if (!strcmp("train", argv[0])) {
+ nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
+ } else if (!strcmp("retrain", argv[0])) {
+ nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
+ nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
+ }
nn_network_train(
network, ml_configs.network_size,
X.data, X.shape,
@@ -139,7 +150,7 @@ int main(int argc, char *argv[]) {
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size);
- // If neither output and file_format defined use input to define the format
+ // If neither output and file_format defined use input to define the output format
if (!ml_configs.file_format && !ml_configs.out_filepath) {
ml_configs.file_format = file_format_infer(ml_configs.in_filepath);
}
Feel free to download, copy and edit any repo