aboutsummaryrefslogtreecommitdiff
path: root/src/nn.c
diff options
context:
space:
mode:
authorjvech <jmvalenciae@unal.edu.co>2023-07-28 09:29:18 -0500
committerjvech <jmvalenciae@unal.edu.co>2023-07-28 09:29:18 -0500
commit1503fc83991237fa0cf6eb42b0ca1a4904cf8a01 (patch)
treebea34a7dd81b469085bbc8cb7a403a92cf46c974 /src/nn.c
parente9b26e6cae80a089f6b969226a968f5b79a7820b (diff)
add: network forward pass implemented
Diffstat (limited to 'src/nn.c')
-rw-r--r--src/nn.c18
1 files changed, 18 insertions, 0 deletions
diff --git a/src/nn.c b/src/nn.c
index d773324..e47b56d 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -2,6 +2,24 @@
static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols);
+void nn_forward(
+ double **out,
+ double *X, size_t X_shape[2],
+ Layer network[], size_t network_size)
+{
+ size_t in_shape[2] = {X_shape[0], X_shape[1]};
+ size_t out_shape[2];
+ out_shape[0] = X_shape[0];
+ double *input = X;
+
+ for (size_t l = 0; l < network_size; l++) {
+ out_shape[1] = network[l].neurons;
+ nn_layer_forward(network[l], out[l], out_shape, input, in_shape);
+ in_shape[1] = out_shape[1];
+ input = out[l];
+ }
+}
+
void nn_layer_forward(Layer layer, double *out, size_t out_shape[2], double *input, size_t input_shape[2])
{
if (out_shape[0] != input_shape[0] || out_shape[1] != layer.neurons) {
Feel free to download, copy and edit any repo