aboutsummaryrefslogtreecommitdiff
path: root/src/nn.h
blob: 94772c20e71b806003c31083c017170a6e983bc9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#ifndef __NN__
#define __NN__

#include <stdlib.h>
#include <assert.h>
#include <stdio.h>
#include <stdint.h>
#include <string.h>
#include <math.h>
#include <unistd.h>
#include <openblas/cblas.h>

/*
 * A single network layer: its parameters, its activation function and its
 * dimensions.
 *
 * NOTE(review): the matrix shapes in the field comments are inferred from the
 * field names only, not from the implementation; confirm them against nn.c.
 */
typedef struct Layer {
    double *weights;                /* weight matrix; presumably input_nodes x neurons */
    double *bias;                   /* bias values; presumably one per neuron */
    double (*activation)(double x); /* scalar activation, e.g. sigmoid/relu/identity */
    size_t neurons;                 /* number of neurons (outputs) in this layer */
    size_t input_nodes;             /* number of inputs feeding this layer */
} Layer;

/*
 * Sets up the parameters of a network made of `nmemb` consecutive layers.
 *
 * NOTE(review): presumably this allocates and initialises each layer's
 * weights/bias, and input_cols is the number of input features seen by the
 * first layer. Confirm against the definition.
 */
void nn_network_init_weights(Layer *network,
                             size_t nmemb,
                             size_t input_cols);

/*
 * Counterpart of nn_network_init_weights for the same `nmemb` layers.
 *
 * NOTE(review): presumably releases the weights/bias buffers set up by
 * nn_network_init_weights. Confirm against the definition.
 */
void nn_network_free_weights(Layer *network,
                             size_t nmemb);

/*
 * Maps a layer's pre-activation matrix through an activation function.
 *
 * NOTE(review): presumably activation is applied element-wise to zout and the
 * results are stored in aout. Each *_shape argument is presumably
 * {rows, cols}. Confirm against the definition.
 */
void nn_layer_map_activation(double (*activation)(double),
                             double *aout, size_t aout_shape[2],
                             double *zout, size_t zout_shape[2]);

/*
 * Forward step of a single layer (the layer is passed by value).
 *
 * NOTE(review): presumably writes input * layer.weights + layer.bias into
 * `out`, before activation. The shape arguments are presumably
 * {rows, cols}. Confirm against the definition.
 */
void nn_layer_forward(
        Layer layer,
        double *out, size_t out_shape[2],
        double *input, size_t input_shape[2]);
/*
 * Backward (update) step for a single layer.
 *
 * NOTE(review): the meanings below are inferred from parameter names only;
 * confirm them against the definition.
 *   weights, weights_shape  presumably the weight matrix ({rows, cols}) that
 *                           is updated in place
 *   delta, dcost_cols       presumably this layer's error term
 *   out_prev, out_cols      presumably the previous layer's outputs
 *   layer                   the layer being updated (passed by value)
 *   alpha                   presumably the learning rate
 */
void nn_layer_backward(
        double *weights, size_t weights_shape[2],
        double *delta, size_t dcost_cols,
        double *out_prev, size_t out_cols,
        Layer layer, double alpha);

/*
 * Scalar activation functions. Each matches the signature of
 * Layer.activation, so any of them can be stored there directly.
 * Definitions are not visible in this header.
 */
double sigmoid(double x);   /* sigmoid activation */
double relu(double x);      /* rectified linear activation */
double identity(double x);  /* identity (no-op) activation */


/*
 * Forward pass through the first `network_size` layers of `network`.
 *
 * NOTE(review): presumably aout[i] and zout[i] receive the post-activation
 * and pre-activation outputs of layer i, and input_shape is {rows, cols}.
 * Who allocates the aout/zout buffers is not visible here; confirm against
 * the definition.
 */
void nn_forward(double **aout, double **zout,
                double *input, size_t input_shape[2],
                Layer network[], size_t network_size);

/*
 * Backward pass over the first `network_size` layers of `network`.
 *
 * The function name is misspelled ("backwad"). It is deliberately kept,
 * because renaming the declaration would break the definition and every
 * caller.
 *
 * cost_derivative is spelled as an explicit function pointer, consistent with
 * the other callbacks in this header. A parameter of function type is
 * adjusted to exactly this pointer type, so the prototype is unchanged.
 *
 * NOTE(review): presumably zout/outs hold one matrix per layer with n_rows
 * rows each, weights holds the per-layer weight matrices, and the arguments
 * to cost_derivative are (prediction, target). Confirm the order against the
 * definition.
 */
void nn_backwad(
        double **weights,
        double **zout, double **outs, size_t n_rows,
        Layer network[], size_t network_size,
        double (*cost_derivative)(double, double));

/*
 * Error term (delta) of the output layer.
 *
 * NOTE(review): presumably delta = error multiplied element-wise by
 * activation_derivative(zout), where each *_cols argument is the column count
 * of the matching array. Confirm against the definition.
 * TODO: the author flagged this declaration as unfinished.
 */
void nn_layer_out_delta(
        double *delta,
        size_t delta_cols,
        double *error,
        size_t error_cols,
        double *zout,
        size_t zout_cols,
        double (*activation_derivative)(double));

/*
 * Error term (delta) of a hidden layer, computed from the next layer's delta.
 *
 * NOTE(review): presumably delta = (delta_next * transpose(weights_next)),
 * multiplied element-wise by activation_derivative(zout). weights_shape is
 * presumably {rows, cols} of weights_next. Confirm against the definition.
 * TODO: the author flagged this declaration as unfinished.
 */
void nn_layer_hidden_delta(
        double *delta, size_t delta_cols,
        double *delta_next, size_t delta_next_cols,
        double *weights_next, size_t weights_shape[2],
        double *zout, size_t zout_cols,
        double (*activation_derivative)(double));
#endif
Feel free to download, copy and edit any repo