aboutsummaryrefslogtreecommitdiff
path: root/src/util.c
diff options
context:
space:
mode:
authorjvech <jmvalenciae@unal.edu.co>2024-08-26 12:12:40 -0500
committerjvech <jmvalenciae@unal.edu.co>2024-08-26 12:12:40 -0500
commite8624e1ebcabcc831d651e0beefe32df1463c903 (patch)
treea5fb8f491b30c01bc301cf0e5c7558cb4497b2b2 /src/util.c
parent65926438256c1ed46993e1c8611597af5a9c23f1 (diff)
add: mini batch learning implemented
Diffstat (limited to 'src/util.c')
-rw-r--r--src/util.c13
1 files changed, 10 insertions, 3 deletions
diff --git a/src/util.c b/src/util.c
index 9a00aa3..71cbf8f 100644
--- a/src/util.c
+++ b/src/util.c
@@ -97,11 +97,12 @@ void usage(int exit_code)
"Options:\n"
" -h, --help Show this message\n"
" -f, --format=FORMAT Define input or output FILE format if needed\n"
+ " -O, --only-out Don't show input fields (only works with predict)\n"
" -a, --alpha=ALPHA Learning rate (only works with train)\n"
+ " -b, --batch=INT Select batch size [default: 32] (only works with train)\n"
+ " -c, --config=FILE Configuration filepath [default=~/.config/ml/ml.cfg]\n"
" -e, --epochs=EPOCHS Epochs to train the model (only works with train)\n"
" -o, --output=FILE Output file (only works with predict)\n"
- " -O, --only-out Don't show input fields (only works with predict)\n"
- " -c, --config=FILE Configuration filepath [default=~/.config/ml/ml.cfg]\n"
" -p, --precision=INT Decimals output precision (only works with predict)\n"
" [default=auto]\n"
"\n"
@@ -117,6 +118,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
{"version", no_argument, 0, 'v'},
{"format", required_argument, 0, 'f'},
{"epochs", required_argument, 0, 'e'},
+ {"batch", required_argument, 0, 'b'},
{"alpha", required_argument, 0, 'a'},
{"output", required_argument, 0, 'o'},
{"config", required_argument, 0, 'c'},
@@ -127,7 +129,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
int c;
while (1) {
- c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:", long_opts, NULL);
+ c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:b:", long_opts, NULL);
if (c == -1) {
break;
@@ -154,6 +156,10 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
case 'p':
ml->decimal_precision = (!strcmp("auto", optarg))? -1: (int)atoi(optarg);
break;
+ case 'b':
+ if (atoi(optarg) <= 0) die("util_load_cli() Error: batch size must be greater than 0");
+ ml->batch_size = (size_t)atol(optarg);
+ break;
case 'h':
usage(0);
break;
@@ -316,6 +322,7 @@ void load_net_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr
if (!strcmp(key, "weights_path")) cfg->weights_filepath = e_strdup(value);
else if (!strcmp(key, "loss")) cfg->loss = e_strdup(value);
else if (!strcmp(key, "epochs")) cfg->epochs = (size_t)atol(value);
+ else if (!strcmp(key, "batch")) cfg->batch_size = (size_t)atol(value);
else if (!strcmp(key, "alpha")) cfg->alpha = (double)atof(value);
else if (!strcmp(key, "inputs")) cfg->input_keys = config_read_values(&(cfg->n_input_keys), value, &strtok_ptr);
else if (!strcmp(key, "labels")) cfg->label_keys = config_read_values(&(cfg->n_label_keys), value, &strtok_ptr);
Feel free to download, copy and edit any repo