aboutsummaryrefslogtreecommitdiff
path: root/src/main.c
diff options
context:
space:
mode:
authorjvech <jmvalenciae@unal.edu.co>2024-09-04 22:15:23 -0500
committerjvech <jmvalenciae@unal.edu.co>2024-09-04 22:15:23 -0500
commita63ee8ba14aad0e364e928399b24196c72a4217f (patch)
treee9561f3ce87bcca5ea0fc8dcf002fa6024723926 /src/main.c
parentf39f6d5b0a907d519377e70876b32daad1a676f2 (diff)
fix: ml_network_train() refactored
Diffstat (limited to 'src/main.c')
-rw-r--r--src/main.c18
1 files changed, 1 insertions, 17 deletions
diff --git a/src/main.c b/src/main.c
index 848d638..63a2fa9 100644
--- a/src/main.c
+++ b/src/main.c
@@ -79,14 +79,6 @@ Layer * load_network(struct Configs cfg)
return network;
}
-struct Cost load_loss(struct Configs cfg)
-{
- extern struct Cost NN_SQUARE;
- if (!strcmp("square", cfg.loss)) return NN_SQUARE;
- die("load_loss() Error: Unknown '%s' loss function", cfg.loss);
- exit(1);
-}
-
int main(int argc, char *argv[]) {
char default_config_path[512], *env_config_path;
struct Configs ml_configs = {
@@ -134,15 +126,7 @@ int main(int argc, char *argv[]) {
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
}
- nn_network_train(
- network, ml_configs.network_size,
- X.data, X.shape,
- y.data, y.shape,
- load_loss(ml_configs),
- ml_configs.epochs,
- ml_configs.batch_size,
- ml_configs.alpha,
- ml_configs.shuffle);
+ nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape);
nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
Feel free to download, copy and edit any repo