diff options
Diffstat (limited to 'src/main.c')
-rw-r--r-- | src/main.c | 48 |
1 files changed, 48 insertions, 0 deletions
@@ -110,6 +110,34 @@ void load_config(struct Configs *cfg, int n_args, ...) die("load_config() Error:"); } +Layer * load_network(struct Configs cfg) +{ + extern struct Activation NN_RELU; + extern struct Activation NN_SOFTPLUS; + extern struct Activation NN_SIGMOID; + extern struct Activation NN_LEAKY_RELU; + + Layer *network = ecalloc(cfg.network_size, sizeof(Layer)); + + for (size_t i = 0; i < cfg.network_size; i++) { + if (!strcmp("relu", cfg.activations[i])) network[i].activation = NN_RELU; + else if (!strcmp("sigmoid", cfg.activations[i])) network[i].activation = NN_SIGMOID; + else if (!strcmp("softplus", cfg.activations[i])) network[i].activation = NN_SOFTPLUS; + else if (!strcmp("leaky_relu", cfg.activations[i])) network[i].activation = NN_LEAKY_RELU; + else die("load_network() Error: Unknown '%s' activation", cfg.activations[i]); + + network[i].neurons = cfg.neurons[i]; + } + return network; +} + +struct Cost load_loss(struct Configs cfg) +{ + extern struct Cost NN_SQUARE; + if (!strcmp("square", cfg.loss)) return NN_SQUARE; + die("load_loss() Error: Unknown '%s' loss function", cfg.loss); + exit(1); +} int main(int argc, char *argv[]) { struct Configs ml_configs = { @@ -122,6 +150,26 @@ int main(int argc, char *argv[]) { // Try different config paths load_config(&ml_configs, 3, "~/.config/ml/ml.cfg", "~/.ml/ml.cfg", ml_configs.config_filepath); util_load_cli(&ml_configs, argc, argv); + argc -= optind; + argv += optind; + Layer *network = load_network(ml_configs); + + Array X, y; + if (!strcmp("train", argv[0])) { + json_read(argv[1], &X, &y, ml_configs.label_keys[0], ml_configs.input_keys, ml_configs.n_input_keys); + nn_network_init_weights(network, ml_configs.network_size, X.shape[1]); + nn_network_train( + network, ml_configs.network_size, + X.data, X.shape, + y.data, y.shape, + load_loss(ml_configs), + ml_configs.epochs, + ml_configs.alpha); + } else if (!strcmp("predict", argv[0])) { + } else usage(1); + + nn_network_free_weights(network, ml_configs.network_size); + free(network); util_free_config(&ml_configs); return 0; } |