diff options
Diffstat (limited to 'src/nn.c')
-rw-r--r-- | src/nn.c | 35 |
1 files changed, 35 insertions, 0 deletions
@@ -2,6 +2,41 @@ static void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols); +void nn_layer_hidden_delta( + double *delta, size_t delta_cols, + double *delta_next, size_t delta_next_cols, + double *weigths_next, size_t weigths_shape[2], + double *zout, size_t zout_cols, + double (*activation_derivative)(double)) +{ + assert(delta_cols == zout_cols); + assert(delta_cols == weigths_shape[0]); + assert(delta_next_cols == weigths_shape[1]); + + for (size_t j = 0; j < delta_cols; j++) { + double sum = 0; + for (size_t k = 0; k < delta_next_cols; k++) { + size_t index = j * delta_cols + k; + sum += delta_next[k] * weigths_next[index]; + } + delta[j] = sum * activation_derivative(zout[j]); + } +} + +void nn_layer_out_delta( + double *delta, size_t delta_cols, + double *error, size_t error_cols, + double *zout, size_t zout_cols, + double (*activation_derivative)(double)) +{ + assert(delta_cols == error_cols); + assert(zout_cols == error_cols); + + for (size_t i = 0; i < delta_cols; i++) { + delta[i] = error[i] * activation_derivative(zout[i]); + } +} + void nn_forward( double **out, double **zout, double *X, size_t X_shape[2], |