aboutsummaryrefslogtreecommitdiff
path: root/src/nn.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/nn.c')
-rw-r--r--src/nn.c54
1 files changed, 39 insertions, 15 deletions
diff --git a/src/nn.c b/src/nn.c
index 916803e..56c35fc 100644
--- a/src/nn.c
+++ b/src/nn.c
@@ -71,7 +71,8 @@ void nn_network_train(
Layer network[], size_t network_size,
double *input, size_t input_shape[2],
double *labels, size_t labels_shape[2],
- struct Cost cost, size_t epochs, double alpha)
+ struct Cost cost, size_t epochs,
+ size_t batch_size, double alpha)
{
assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n");
@@ -86,27 +87,50 @@ void nn_network_train(
size_t samples = input_shape[0];
for (size_t l = 0; l < network_size; l++) {
- outs[l] = calloc(samples * network[l].neurons, sizeof(double));
- zouts[l] = calloc(samples * network[l].neurons, sizeof(double));
+ outs[l] = calloc(batch_size * network[l].neurons, sizeof(double));
+ zouts[l] = calloc(batch_size * network[l].neurons, sizeof(double));
weights[l] = malloc(network[l].input_nodes * network[l].neurons * sizeof(double));
biases[l] = malloc(network[l].neurons * sizeof(double));
+ if (!outs[l] || !zouts || !weights[l] || !biases) goto nn_network_train_error;
+
+
memcpy(weights[l], network[l].weights, sizeof(double) * network[l].input_nodes * network[l].neurons);
memcpy(biases[l], network[l].bias, sizeof(double) * network[l].neurons);
}
+
+ size_t batch_input_shape[2] = {batch_size, input_shape[1]};
+ size_t batch_labels_shape[2] = {batch_size, labels_shape[1]};
+ size_t n_batches = input_shape[0] / batch_size;
+ if (samples % batch_size) {
+ n_batches++;
+ }
for (size_t epoch = 0; epoch < epochs; epoch++) {
- nn_forward(outs, zouts, input, input_shape, network, network_size);
- nn_backward(
- weights, biases,
- zouts, outs,
- input, input_shape,
- labels, labels_shape,
- network, network_size,
- cost.dfunc_out, alpha);
- double *net_out = outs[network_size - 1];
- fprintf(stdout, "epoch: %zu \t loss: %6.6lf\n",
- epoch, get_avg_loss(labels, net_out, labels_shape, cost.func));
+ for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
+ size_t index = batch_size * batch_idx;
+
+ double *input_batch = input + index * input_shape[1];
+ double *labels_batch = labels + index * labels_shape[1];
+
+ if (batch_idx == n_batches - 1 && samples % batch_size) {
+ batch_input_shape[0] = samples % batch_size;
+ batch_labels_shape[0] = samples % batch_size;
+ }
+
+ nn_forward(outs, zouts, input_batch, batch_input_shape, network, network_size);
+ nn_backward(
+ weights, biases,
+ zouts, outs,
+ input_batch, batch_input_shape,
+ labels_batch, batch_labels_shape,
+ network, network_size,
+ cost.dfunc_out, alpha);
+ double *net_out = outs[network_size - 1];
+ fprintf(stdout, "epoch: %g \t loss: %6.6lf\n",
+ epoch + (float)batch_idx / n_batches,
+ get_avg_loss(labels, net_out, batch_labels_shape, cost.func));
+ }
}
for (size_t l = 0; l < network_size; l++) {
@@ -417,7 +441,7 @@ void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols
}
for (size_t i = 0; i < cols; i++) {
- bias[i] = (double)random_bias[i] / (double)INT64_MAX * 2;
+ bias[i] = (double)random_bias[i] / (double)INT64_MAX * 2;
}
free(random_weights);
Feel free to download, copy and edit any repo