aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorjvech <jmvalenciae@unal.edu.co>2023-08-27 20:43:23 -0500
committerjvech <jmvalenciae@unal.edu.co>2023-08-27 20:43:23 -0500
commit8daf85f463d159b2b69939233c18760d72b6f4ab (patch)
tree019641bf7eac5fbbad66d2116a2402f1b008f19f /src
parent624c04b33afff299121a5ded475070a2f0236cff (diff)
add: data and network initialization done
train subcommand can read and train the network TODO: - Refactor json_read() to parse multiple labels - Implement a function to save network weights once the network have trained - Implement a function to load trained weights to use with predict subcommand
Diffstat (limited to 'src')
-rw-r--r--src/main.c48
-rw-r--r--src/util.h1
2 files changed, 49 insertions, 0 deletions
diff --git a/src/main.c b/src/main.c
index 8f68d1e..30aca1e 100644
--- a/src/main.c
+++ b/src/main.c
@@ -110,6 +110,34 @@ void load_config(struct Configs *cfg, int n_args, ...)
die("load_config() Error:");
}
+Layer * load_network(struct Configs cfg)
+{
+ extern struct Activation NN_RELU;
+ extern struct Activation NN_SOFTPLUS;
+ extern struct Activation NN_SIGMOID;
+ extern struct Activation NN_LEAKY_RELU;
+
+ Layer *network = ecalloc(cfg.network_size, sizeof(Layer));
+
+ for (size_t i = 0; i < cfg.network_size; i++) {
+ if (!strcmp("relu", cfg.activations[i])) network[i].activation = NN_RELU;
+ else if (!strcmp("sigmoid", cfg.activations[i])) network[i].activation = NN_SIGMOID;
+ else if (!strcmp("softplus", cfg.activations[i])) network[i].activation = NN_SOFTPLUS;
+ else if (!strcmp("leaky_relu", cfg.activations[i])) network[i].activation = NN_LEAKY_RELU;
+ else die("load_network() Error: Unknown '%s' activation", cfg.activations[i]);
+
+ network[i].neurons = cfg.neurons[i];
+ }
+ return network;
+}
+
+struct Cost load_loss(struct Configs cfg)
+{
+ extern struct Cost NN_SQUARE;
+ if (!strcmp("square", cfg.loss)) return NN_SQUARE;
+ die("load_loss() Error: Unknown '%s' loss function", cfg.loss);
+ exit(1);
+}
int main(int argc, char *argv[]) {
struct Configs ml_configs = {
@@ -122,6 +150,26 @@ int main(int argc, char *argv[]) {
// Try different config paths
load_config(&ml_configs, 3, "~/.config/ml/ml.cfg", "~/.ml/ml.cfg", ml_configs.config_filepath);
util_load_cli(&ml_configs, argc, argv);
+ argc -= optind;
+ argv += optind;
+ Layer *network = load_network(ml_configs);
+
+ Array X, y;
+ if (!strcmp("train", argv[0])) {
+ json_read(argv[1], &X, &y, ml_configs.label_keys[0], ml_configs.input_keys, ml_configs.n_input_keys);
+ nn_network_init_weights(network, ml_configs.network_size, X.shape[1]);
+ nn_network_train(
+ network, ml_configs.network_size,
+ X.data, X.shape,
+ y.data, y.shape,
+ load_loss(ml_configs),
+ ml_configs.epochs,
+ ml_configs.alpha);
+ } else if (!strcmp("predict", argv[0])) {
+ } else usage(1);
+
+ nn_network_free_weights(network, ml_configs.network_size);
+ free(network);
util_free_config(&ml_configs);
return 0;
}
diff --git a/src/util.h b/src/util.h
index 00567a8..3219502 100644
--- a/src/util.h
+++ b/src/util.h
@@ -21,6 +21,7 @@ struct Configs {
char **activations;
};
+void usage(int exit_code);
void die(const char *fmt, ...);
void *ecalloc(size_t nmemb, size_t size);
void *erealloc(void *ptr, size_t size);
Feel free to download, copy and edit any repo