aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main.c25
-rw-r--r--src/parse.c7
-rw-r--r--src/util.c2
3 files changed, 25 insertions, 9 deletions
diff --git a/src/main.c b/src/main.c
index 38f26ad..22737dc 100644
--- a/src/main.c
+++ b/src/main.c
@@ -51,7 +51,7 @@ void load_config(struct Configs *cfg, int n_args, ...)
} else break;
}
va_end(ap);
- die("load_config() Error:");
+ die("load_config('%s') Error:", filepath);
}
Layer * load_network(struct Configs cfg)
@@ -88,11 +88,11 @@ struct Cost load_loss(struct Configs cfg)
}
int main(int argc, char *argv[]) {
- char default_config_path[512];
+ char default_config_path[512], *env_config_path;
struct Configs ml_configs = {
.epochs = 100,
.alpha = 1e-5,
- .config_filepath = "utils/settings.cfg",
+ .config_filepath = "",
.network_size = 0,
.only_out = false,
.decimal_precision = -1,
@@ -103,9 +103,15 @@ int main(int argc, char *argv[]) {
// First past to check if --config option was put
util_load_cli(&ml_configs, argc, argv);
optind = 1;
+
// Load configs with different possible paths
sprintf(default_config_path, "%s/%s", getenv("HOME"), ".config/ml/ml.cfg");
- load_config(&ml_configs, 2, ml_configs.config_filepath, default_config_path);
+ env_config_path = (getenv("ML_CONFIG_PATH"))? getenv("ML_CONFIG_PATH"):"";
+
+ load_config(&ml_configs, 3,
+ ml_configs.config_filepath,
+ env_config_path,
+ default_config_path);
// re-read cli options again, to overwrite file configuration options
util_load_cli(&ml_configs, argc, argv);
@@ -115,12 +121,17 @@ int main(int argc, char *argv[]) {
Layer *network = load_network(ml_configs);
Array X, y;
- if (!strcmp("train", argv[0])) {
+ if (!strcmp("train", argv[0]) || !strcmp("retrain", argv[0])) {
file_read(argv[1], &X, &y,
ml_configs.input_keys, ml_configs.n_input_keys,
ml_configs.label_keys, ml_configs.n_label_keys,
true, ml_configs.file_format);
- nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
+ if (!strcmp("train", argv[0])) {
+ nn_network_init_weights(network, ml_configs.network_size, X.shape[1], true);
+ } else if (!strcmp("retrain", argv[0])) {
+ nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
+ nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
+ }
nn_network_train(
network, ml_configs.network_size,
X.data, X.shape,
@@ -139,7 +150,7 @@ int main(int argc, char *argv[]) {
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
nn_network_predict(y.data, y.shape, X.data, X.shape, network, ml_configs.network_size);
- // If neither output and file_format defined use input to define the format
+ // If neither output and file_format defined use input to define the output format
if (!ml_configs.file_format && !ml_configs.out_filepath) {
ml_configs.file_format = file_format_infer(ml_configs.in_filepath);
}
diff --git a/src/parse.c b/src/parse.c
index cea595b..a06f0f3 100644
--- a/src/parse.c
+++ b/src/parse.c
@@ -187,6 +187,11 @@ void json_read(
die("json_read() Error: unexpected JSON data received, expecting an object");
}
+ if ((size_t)json_object_object_length(item) < n_input_keys + n_out_keys) {
+ die("json_read() Error: the number of keys required is greater "
+ "than the keys available in the object:\n%s",
+ json_object_to_json_string_ext(item, JSON_C_TO_STRING_PRETTY));
+ }
for (j = 0; j < n_input_keys; j++) {
value = json_object_object_get(item, in_keys[j]);
obj_type = json_object_get_type(value);
@@ -517,7 +522,7 @@ int main(int argc, char *argv[]) {
// use input format if format variable is not defined
format = (!format && !strcmp(out_file, "-")) ? file_format_infer(in_file) : format;
- file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format);
+ file_write(out_file, X, y, in_cols, n_in_cols, out_cols, n_out_cols, true, format, -1);
for (i = 0; i < n_in_cols; i++) free(in_cols[i]);
for (i = 0; i < n_out_cols; i++) free(out_cols[i]);
diff --git a/src/util.c b/src/util.c
index 4621836..9a00aa3 100644
--- a/src/util.c
+++ b/src/util.c
@@ -91,7 +91,7 @@ void usage(int exit_code)
{
FILE *fp = (!exit_code) ? stdout : stderr;
fprintf(fp,
- "Usage: ml train [Options] FILE\n"
+ "Usage: ml [re]train [Options] FILE\n"
" or: ml predict [-Ohv] [-f FORMAT] [-o FILE] [-p INT] FILE\n"
"\n"
"Options:\n"
Feel free to download, copy and edit any repo