aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/main.c88
-rw-r--r--src/util.c98
-rw-r--r--src/util.h22
3 files changed, 119 insertions, 89 deletions
diff --git a/src/main.c b/src/main.c
index 683447f..0071426 100644
--- a/src/main.c
+++ b/src/main.c
@@ -16,33 +16,6 @@ typedef struct Array {
#define ARRAY_SIZE(x, type) sizeof(x) / sizeof(type)
-static void version()
-{
- printf("ml 0.1\n");
- printf("Written by vech\n");
- exit(0);
-}
-
-static void usage(int exit_code)
-{
- FILE *fp = (!exit_code) ? stdout : stderr;
- fprintf(fp,
- "Usage: ml train [Options] JSON_FILE\n"
- " or: ml predict [-o FILE] FILE\n"
- "Train and predict json data\n"
- "\n"
- "Options:\n"
- " -a, --alpha=ALPHA Learning rate (only works with train) [default: 1e-5]\n"
- " -e, --epochs=EPOCHS Number of epochs to train the model (only works with train)\n"
- " [default: 100]\n"
- " -o, --output FILE Output file (only works with predict)\n"
- "\n"
- "Examples:\n"
- " $ ml train -e 150 -a 1e-4 housing.json\n"
- " $ ml predict housing.json -o predictions.json\n"
- );
- exit(exit_code);
-}
static void json_read(const char *filepath,
Array *input, Array *out,
@@ -116,61 +89,12 @@ json_read_error:
}
int main(int argc, char *argv[]) {
-
- char **input_keys, **label_keys;
- int in_key_size = 0, label_key_size = 0;
- char *out_filename = NULL, *in_filename = NULL;
- size_t epochs = 100;
- double alpha = 1e-5;
-
- if (argc <= 1) usage(1);
- static struct option long_opts[] = {
- {"help", no_argument, 0, 'h'},
- {"version", no_argument, 0, 'v'},
- {"epochs", required_argument, 0, 'e'},
- {"alpha", required_argument, 0, 'a'},
- {"output", required_argument, 0, 'o'},
- {0, 0, 0, 0 },
+ struct Configs ml_configs = {
+ .epochs = 100,
+ .alpha = 1e-5,
+ .config_filepath = "utils/config.yml",
};
- int c;
-
- while (1) {
- c = getopt_long(argc, argv, "hve:a:o:i:l:", long_opts, NULL);
-
- if (c == -1) {
- break;
- }
- switch (c) {
- case 'e':
- epochs = (size_t)atol(optarg);
- break;
- case 'a':
- alpha = (double)atof(optarg);
- break;
- case 'o':
- out_filename = optarg;
- break;
- case 'h':
- usage(0);
- case 'v':
- version();
- default:
- usage(1);
- }
- }
-
- argv += optind;
- argc -= optind;
- if (argc != 2) usage(1);
-
- in_filename = argv[1];
-
- if (!strcmp(argv[0], "train")) {
- printf("train command\n");
- printf("in_filename: '%s'\n", in_filename);
- } else if (!strcmp(argv[0], "predict")) {
- } else {
- usage(1);
- }
+ util_load_config(&ml_configs);
+ util_load_cli(&ml_configs, argc, argv);
return 0;
}
diff --git a/src/util.c b/src/util.c
index 4a8c3ef..9c1dba3 100644
--- a/src/util.c
+++ b/src/util.c
@@ -1,10 +1,12 @@
/* See LICENSE file for copyright and license details. */
#include <string.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <getopt.h>
#include "util.h"
-void
-die(const char *fmt, ...)
+void die(const char *fmt, ...)
{
va_list ap;
@@ -22,8 +24,7 @@ die(const char *fmt, ...)
exit(1);
}
-void *
-ecalloc(size_t nmemb, size_t size)
+void * ecalloc(size_t nmemb, size_t size)
{
void *p;
@@ -31,3 +32,92 @@ ecalloc(size_t nmemb, size_t size)
die("calloc:");
return p;
}
+
+
+void version()
+{
+ printf("ml 0.1\n");
+ printf("Written by vech\n");
+ exit(0);
+}
+
+void usage(int exit_code)
+{
+ FILE *fp = (!exit_code) ? stdout : stderr;
+ fprintf(fp,
+ "Usage: ml train [Options] JSON_FILE\n"
+ " or: ml predict [-o FILE] FILE\n"
+ "Train and predict json data\n"
+ "\n"
+ "Options:\n"
+ " -a, --alpha=ALPHA Learning rate (only works with train) [default: 1e-5]\n"
+ " -e, --epochs=EPOCHS Number of epochs to train the model (only works with train)\n"
+ " [default: 100]\n"
+ " -o, --output FILE Output file (only works with predict)\n"
+ "\n"
+ "Examples:\n"
+ " $ ml train -e 150 -a 1e-4 housing.json\n"
+ " $ ml predict housing.json -o predictions.json\n"
+ );
+ exit(exit_code);
+}
+
+void util_load_cli(struct Configs *ml, int argc, char *argv[])
+{
+ if (argc <= 1) usage(1);
+ static struct option long_opts[] = {
+ {"help", no_argument, 0, 'h'},
+ {"version", no_argument, 0, 'v'},
+ {"epochs", required_argument, 0, 'e'},
+ {"alpha", required_argument, 0, 'a'},
+ {"output", required_argument, 0, 'o'},
+ {0, 0, 0, 0 },
+ };
+ int c;
+
+ while (1) {
+ c = getopt_long(argc, argv, "hve:a:o:i:l:", long_opts, NULL);
+
+ if (c == -1) {
+ break;
+ }
+ switch (c) {
+ case 'e':
+ ml->epochs = (size_t)atol(optarg);
+ break;
+ case 'a':
+ ml->alpha = (double)atof(optarg);
+ break;
+ case 'o':
+ ml->out_filepath = optarg;
+ break;
+ case 'h':
+ usage(0);
+ case 'v':
+ version();
+ default:
+ usage(1);
+ }
+ }
+
+ argv += optind;
+ argc -= optind;
+ if (argc != 2) usage(1);
+
+ ml->in_filepath = argv[1];
+}
+
+void util_free_config(struct Configs *ml)
+{
+ if (ml->input_keys != NULL) {
+ for (size_t i = 0; i < ml->n_input_keys; i++)
+ free(ml->input_keys[i]);
+ free(ml->input_keys);
+ }
+
+ if (ml->label_keys != NULL) {
+ for (size_t i = 0; i < ml->n_label_keys; i++)
+ free(ml->label_keys[i]);
+ free(ml->label_keys);
+ }
+}
diff --git a/src/util.h b/src/util.h
index 4068691..e1bcf18 100644
--- a/src/util.h
+++ b/src/util.h
@@ -1,6 +1,22 @@
-#include <stdarg.h>
-#include <stdio.h>
-#include <stdlib.h>
+#ifndef UTIL_
+#define UTIL_
+
+#include <stddef.h>
+
+struct Configs {
+ size_t epochs;
+ double alpha;
+ char **input_keys, **label_keys;
+ size_t n_input_keys, n_label_keys;
+ char *in_filepath;
+ char *out_filepath;
+ char *weights_filepath;
+ char *config_filepath;
+};
void die(const char *fmt, ...);
void *ecalloc(size_t nmemb, size_t size);
+void util_load_cli(struct Configs *ml, int argc, char *argv[]);
+void util_load_config(struct Configs *ml);
+void util_free_config(struct Configs *ml);
+#endif
Feel free to download, copy and edit any repo