about | summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author: jvech <jmvalenciae@unal.edu.co> 2024-09-04 22:15:23 -0500
committer: jvech <jmvalenciae@unal.edu.co> 2024-09-04 22:15:23 -0500
commit: a63ee8ba14aad0e364e928399b24196c72a4217f (patch)
tree: e9561f3ce87bcca5ea0fc8dcf002fa6024723926 /src
parent: f39f6d5b0a907d519377e70876b32daad1a676f2 (diff)
fix: ml_network_train() refactored
Diffstat (limited to 'src')
-rw-r--r-- src/main.c 18
-rw-r--r-- src/nn.c 31
-rw-r--r-- src/nn.h 11
3 files changed, 28 insertions, 32 deletions
diff --git a/src/main.c b/src/main.c
index 848d638..63a2fa9 100644
--- a/src/main.c
+++ b/src/main.c
@@ -79,14 +79,6 @@ Layer * load_network(struct Configs cfg)
return network;
}
-struct Cost load_loss(struct Configs cfg)
-{
- extern struct Cost NN_SQUARE;
- if (!strcmp("square", cfg.loss)) return NN_SQUARE;
- die("load_loss() Error: Unknown '%s' loss function", cfg.loss);
- exit(1);
-}
-
int main(int argc, char *argv[]) {
char default_config_path[512], *env_config_path;
struct Configs ml_configs = {
@@ -134,15 +126,7 @@ int main(int argc, char *argv[]) {
nn_network_init_weights(network, ml_configs.network_size, X.shape[1], false);
nn_network_read_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
}
- nn_network_train(
- network, ml_configs.network_size,
- X.data, X.shape,
- y.data, y.shape,
- load_loss(ml_configs),
- ml_configs.epochs,
- ml_configs.batch_size,
- ml_configs.alpha,
- ml_configs.shuffle);
+ nn_network_train(network, ml_configs, X.data, X.shape, y.data, y.shape);
nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
} else if (!strcmp("predict", argv[0])) {
diff --git a/src/nn.c b/src/nn.c
index 867819c..19076bb 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -29,6 +29,8 @@
#include "util.h"
#include "nn.h"
+
+struct Cost load_loss(struct Configs cfg);
static void dataset_shuffle_rows(
double *inputs, size_t in_shape[2],
double *labels, size_t lbl_shape[2]);
@@ -73,15 +75,19 @@ void nn_network_predict(
}
void nn_network_train(
- Layer network[], size_t network_size,
+ Layer network[], struct Configs ml_configs,
double *input, size_t input_shape[2],
- double *labels, size_t labels_shape[2],
- struct Cost cost, size_t epochs,
- size_t batch_size, double alpha,
- bool shuffle)
+ double *labels, size_t labels_shape[2])
{
assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n");
+ size_t epochs = ml_configs.epochs;
+ size_t batch_size = ml_configs.batch_size;
+ size_t network_size = ml_configs.network_size;
+ double alpha = ml_configs.alpha;
+ bool shuffle = ml_configs.shuffle;
+ struct Cost cost = load_loss(ml_configs);
+
double **outs = calloc(network_size, sizeof(double *));
double **zouts = calloc(network_size, sizeof(double *));
double **weights = calloc(network_size, sizeof(double *));
@@ -204,19 +210,19 @@ void nn_backward(
double *out_prev = Outs[l - 1] + sample * network[l-1].neurons;
double *dcost_out = dcost_outs + sample * network[l].neurons;
nn_layer_out_delta(delta, dcost_out, zout, network[l].neurons, network[l].activation.dfunc);
- nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha);
+ nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha);
} else if (l == 0) {
size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons};
double *zout = Zout[l] + sample * network[l].neurons;
double *input = Input + sample * input_shape[1];
nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc);
- nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, network[l], alpha);
+ nn_layer_backward(weights[l], bias[l], weights_shape, delta, input, alpha);
} else {
size_t weights_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons};
double *zout = Zout[l] + sample * network[l].neurons;
double *out_prev = Outs[l - 1] + sample * network[l-1].neurons;
nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weights_next_shape, network[l].activation.dfunc);
- nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, network[l], alpha);
+ nn_layer_backward(weights[l], bias[l], weights_shape, delta, out_prev, alpha);
}
memmove(delta_next, delta, weights_shape[1] * sizeof(double));
}
@@ -241,7 +247,7 @@ nn_backward_error:
void nn_layer_backward(
double *weights, double *bias, size_t weights_shape[2],
double *delta, double *out_prev,
- Layer layer, double alpha)
+ double alpha)
{
// W_next = W - alpha * out_prev @ delta.T
cblas_dger(CblasRowMajor, weights_shape[0], weights_shape[1], -alpha,
@@ -544,6 +550,13 @@ double get_avg_loss(
return sum / shape[0];
}
+struct Cost load_loss(struct Configs cfg)
+{
+ if (!strcmp("square", cfg.loss)) return NN_SQUARE;
+ die("load_loss() Error: Unknown '%s' loss function", cfg.loss);
+ exit(1);
+}
+
#ifdef NN_TEST
/*
* compile: clang -Wall -Wextra -g -DNN_TEST -o objs/test_nn src/util.c src/nn.c $(pkg-config --libs-only-l blas) -lm
diff --git a/src/nn.h b/src/nn.h
index 9f8e2a5..105fae7 100644
--- a/src/nn.h
+++ b/src/nn.h
@@ -22,6 +22,8 @@
#include <stdbool.h>
#include <stddef.h>
+#include "util.h"
+
struct Cost {
double (*func)(double labels[], double net_out[], size_t shape);
double (*dfunc_out)(double labels, double net_out);
@@ -49,12 +51,9 @@ void nn_network_predict(
Layer network[], size_t network_size);
void nn_network_train(
- Layer network[], size_t network_size,
+ Layer network[], struct Configs ml_configs,
double *input, size_t input_shape[2],
- double *labels, size_t labels_shape[2],
- struct Cost cost, size_t epochs,
- size_t batch_size, double alpha,
- bool shuffle);
+ double *labels, size_t labels_shape[2]);
void nn_layer_map_activation(
double (*activation)(double),
@@ -89,7 +88,7 @@ void nn_layer_forward(
void nn_layer_backward(
double *weights, double *bias, size_t weigths_shape[2],
double *delta, double *out_prev,
- Layer layer, double alpha);
+ double alpha);
void nn_layer_out_delta(
double *delta, double *dcost_out, double *zout, size_t cols,
Feel free to download, copy and edit any repo