From a63ee8ba14aad0e364e928399b24196c72a4217f Mon Sep 17 00:00:00 2001 From: jvech Date: Wed, 4 Sep 2024 22:15:23 -0500 Subject: fix: ml_network_train() refactored --- src/main.c | 18 +----------------- 1 file changed, 1 insertion(+), 17 deletions(-) (limited to 'src/main.c') diff --git a/src/main.c b/src/main.c index 848d638..63a2fa9 100644 --- a/src/main.c +++ b/src/main.c @@ -79,14 +79,6 @@ Layer * load_network(struct Configs cfg) return network; } -struct Cost load_loss(struct Configs cfg) -{ - extern struct Cost NN_SQUARE; - if (!strcmp("square", cfg.loss)) return NN_SQUARE; - die("load_loss() Error: Unknown '%s' loss function", cfg.loss); - exit(1); -} - int main(int argc, char *argv[]) { char default_config_path[512], *env_config_path; struct Configs ml_configs = { @@ -134,15 +126,7 @@ int main(int argc, char *argv[]) { nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false); nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size); } - nn_network_train( - network, ml_configs.network_size, - X.data, X.shape, - y.data, y.shape, - load_loss(ml_configs), - ml_configs.epochs, - ml_configs.batch_size, - ml_configs.alpha, - ml_configs.shuffle); + nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape); nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size); fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath); } else if (!strcmp("predict", argv[0])) { -- cgit v1.2.3-70-g09d2