diff options
author | jvech <jmvalenciae@unal.edu.co> | 2023-08-27 20:43:23 -0500 |
---|---|---|
committer | jvech <jmvalenciae@unal.edu.co> | 2023-08-27 20:43:23 -0500 |
commit | 8daf85f463d159b2b69939233c18760d72b6f4ab (patch) | |
tree | 019641bf7eac5fbbad66d2116a2402f1b008f19f /src | |
parent | 624c04b33afff299121a5ded475070a2f0236cff (diff) |
add: data and network initialization done
train subcommand can read and train the network
TODO:
- Refactor json_read() to parse multiple labels
- Implement a function to save network weights once the network have
trained
- Implement a function to load trained weights to use with predict
subcommand
Diffstat (limited to 'src')
-rw-r--r-- | src/main.c | 48 | ||||
-rw-r--r-- | src/util.h | 1 |
2 files changed, 49 insertions, 0 deletions
@@ -110,6 +110,34 @@ void load_config(struct Configs *cfg, int n_args, ...) die("load_config() Error:"); } +Layer * load_network(struct Configs cfg) +{ + extern struct Activation NN_RELU; + extern struct Activation NN_SOFTPLUS; + extern struct Activation NN_SIGMOID; + extern struct Activation NN_LEAKY_RELU; + + Layer *network = ecalloc(cfg.network_size, sizeof(Layer)); + + for (size_t i = 0; i < cfg.network_size; i++) { + if (!strcmp("relu", cfg.activations[i])) network[i].activation = NN_RELU; + else if (!strcmp("sigmoid", cfg.activations[i])) network[i].activation = NN_SIGMOID; + else if (!strcmp("softplus", cfg.activations[i])) network[i].activation = NN_SOFTPLUS; + else if (!strcmp("leaky_relu", cfg.activations[i])) network[i].activation = NN_LEAKY_RELU; + else die("load_network() Error: Unknown '%s' activation", cfg.activations[i]); + + network[i].neurons = cfg.neurons[i]; + } + return network; +} + +struct Cost load_loss(struct Configs cfg) +{ + extern struct Cost NN_SQUARE; + if (!strcmp("square", cfg.loss)) return NN_SQUARE; + die("load_loss() Error: Unknown '%s' loss function", cfg.loss); + exit(1); +} int main(int argc, char *argv[]) { struct Configs ml_configs = { @@ -122,6 +150,26 @@ int main(int argc, char *argv[]) { // Try different config paths load_config(&ml_configs, 3, "~/.config/ml/ml.cfg", "~/.ml/ml.cfg", ml_configs.config_filepath); util_load_cli(&ml_configs, argc, argv); + argc -= optind; + argv += optind; + Layer *network = load_network(ml_configs); + + Array X, y; + if (!strcmp("train", argv[0])) { + json_read(argv[1], &X, &y, ml_configs.label_keys[0], ml_configs.input_keys, ml_configs.n_input_keys); + nn_network_init_weights(network, ml_configs.network_size, X.shape[1]); + nn_network_train( + network, ml_configs.network_size, + X.data, X.shape, + y.data, y.shape, + load_loss(ml_configs), + ml_configs.epochs, + ml_configs.alpha); + } else if (!strcmp("predict", argv[0])) { + } else usage(1); + + nn_network_free_weights(network, ml_configs.network_size); + free(network); util_free_config(&ml_configs); return 0; } @@ -21,6 +21,7 @@ struct Configs { char **activations; }; +void usage(int exit_code); void die(const char *fmt, ...); void *ecalloc(size_t nmemb, size_t size); void *erealloc(void *ptr, size_t size); |