From 179613cd790ddf87c3fc501b346f242d17065917 Mon Sep 17 00:00:00 2001 From: jvech Date: Sat, 5 Aug 2023 20:48:21 -0500 Subject: add: bias backward propagation implemented --- src/nn.c | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'src/nn.c') diff --git a/src/nn.c b/src/nn.c index ca04003..83fd77a 100644 --- a/src/nn.c +++ b/src/nn.c @@ -3,7 +3,7 @@ static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); void nn_backward( - double **weights, + double **weights, double **bias, double **Zout, double **Outs, double *Input, size_t input_shape[2], double *Labels, size_t labels_shape[2], @@ -33,19 +33,20 @@ void nn_backward( double *zout = Zout[l] + sample * network[l].neurons; double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; nn_layer_out_delta(delta, dcost_out, zout, network[l].neurons, network[l].activation_derivative); - nn_layer_backward(weights[l], weigths_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weigths_shape, delta, out_prev, network[l], alpha); } else if (l == 0) { size_t weigths_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *input = Input + sample * input_shape[1]; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weigths_next_shape, network[l].activation_derivative); - nn_layer_backward(weights[l], weigths_shape, delta, input, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weigths_shape, delta, input, network[l], alpha); + break; } else { size_t weigths_next_shape[2] = {network[l+1].input_nodes, network[l+1].neurons}; double *zout = Zout[l] + sample * network[l].neurons; double *out_prev = Outs[l - 1] + sample * network[l-1].neurons; nn_layer_hidden_delta(delta, delta_next, zout, weights[l+1], weigths_next_shape, network[l].activation_derivative); - nn_layer_backward(weights[l], weigths_shape, delta, out_prev, network[l], alpha); + nn_layer_backward(weights[l], bias[l], weigths_shape, delta, out_prev, network[l], alpha); } memcpy(delta_next, delta, weigths_shape[1] * sizeof(double)); } @@ -57,7 +58,7 @@ void nn_backward( } void nn_layer_backward( - double *weights, size_t weigths_shape[2], + double *weights, double *bias, size_t weigths_shape[2], double *delta, double *out_prev, Layer layer, double alpha) { @@ -65,9 +66,12 @@ void nn_layer_backward( for (size_t j = 0; j < weigths_shape[1]; j++) { size_t index = weigths_shape[1] * i + j; double dcost_w = delta[j] * out_prev[i]; - weights[index] = layer.weights[index] + alpha * dcost_w; + weights[index] = layer.weights[index] - alpha * dcost_w; } } + + for (size_t j = 0; j < weigths_shape[1]; j++) + bias[j] = layer.bias[j] - alpha * delta[j]; } void nn_layer_hidden_delta( -- cgit v1.2.3-70-g09d2