diff options
Diffstat (limited to 'src/main.c')
-rw-r--r-- | src/main.c | 18 |
1 files changed, 1 insertions, 17 deletions
@@ -79,14 +79,6 @@ Layer * load_network(struct Configs cfg) return network; } -struct Cost load_loss(struct Configs cfg) -{ - extern struct Cost NN_SQUARE; - if (!strcmp("square", cfg.loss)) return NN_SQUARE; - die("load_loss() Error: Unknown '%s' loss function", cfg.loss); - exit(1); -} - int main(int argc, char *argv[]) { char default_config_path[512], *env_config_path; struct Configs ml_configs = { @@ -134,15 +126,7 @@ int main(int argc, char *argv[]) { nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); } - nn_network_train( - network, ml_configs.network_size, - X.data, X.shape, - y.data, y.shape, - load_loss(ml_configs), - ml_configs.epochs, - ml_configs.batch_size, - ml_configs.alpha, - ml_configs.shuffle); + nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape); nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { |