diff options
Diffstat (limited to 'src/util.c')
-rw-r--r-- | src/util.c | 13 |
1 files changed, 10 insertions, 3 deletions
@@ -97,11 +97,12 @@ void usage(int exit_code) "Options:\n" " -h, --help Show this message\n" " -f, --format=FORMAT Define input or output FILE format if needed\n" + " -O, --only-out Don't show input fields (only works with predict)\n" " -a, --alpha=ALPHA Learning rate (only works with train)\n" + " -b, --batch=INT Select batch size [default: 32] (only works with train)\n" + " -c, --config=FILE Configuration filepath [default=~/.config/ml/ml.cfg]\n" " -e, --epochs=EPOCHS Epochs to train the model (only works with train)\n" " -o, --output=FILE Output file (only works with predict)\n" - " -O, --only-out Don't show input fields (only works with predict)\n" - " -c, --config=FILE Configuration filepath [default=~/.config/ml/ml.cfg]\n" " -p, --precision=INT Decimals output precision (only works with predict)\n" " [default=auto]\n" "\n" @@ -117,6 +118,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) {"version", no_argument, 0, 'v'}, {"format", required_argument, 0, 'f'}, {"epochs", required_argument, 0, 'e'}, + {"batch", required_argument, 0, 'b'}, {"alpha", required_argument, 0, 'a'}, {"output", required_argument, 0, 'o'}, {"config", required_argument, 0, 'c'}, @@ -127,7 +129,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) int c; while (1) { - c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:", long_opts, NULL); + c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:b:", long_opts, NULL); if (c == -1) { break; @@ -154,6 +156,10 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[]) case 'p': ml->decimal_precision = (!strcmp("auto", optarg))? -1: (int)atoi(optarg); break; + case 'b': + if (atoi(optarg) <= 0) die("util_load_cli() Error: batch size must be greater than 0"); + ml->batch_size = (size_t)atol(optarg); + break; case 'h': usage(0); break; @@ -316,6 +322,7 @@ void load_net_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr if (!strcmp(key, "weights_path")) cfg->weights_filepath = e_strdup(value); else if (!strcmp(key, "loss")) cfg->loss = e_strdup(value); else if (!strcmp(key, "epochs")) cfg->epochs = (size_t)atol(value); + else if (!strcmp(key, "batch")) cfg->batch_size = (size_t)atol(value); else if (!strcmp(key, "alpha")) cfg->alpha = (double)atof(value); else if (!strcmp(key, "inputs")) cfg->input_keys = config_read_values(&(cfg->n_input_keys), value, &strtok_ptr); else if (!strcmp(key, "labels")) cfg->label_keys = config_read_values(&(cfg->n_label_keys), value, &strtok_ptr); |