diff options
Diffstat (limited to 'src/nn.c')
-rw-r--r-- | src/nn.c | 54 |
1 files changed, 39 insertions, 15 deletions
@@ -71,7 +71,8 @@ void nn_network_train( Layer network[], size_t network_size, double *input, size_t input_shape[2], double *labels, size_t labels_shape[2], - struct Cost cost, size_t epochs, double alpha) + struct Cost cost, size_t epochs, + size_t batch_size, double alpha) { assert(input_shape[0] == labels_shape[0] && "label samples don't correspond with input samples\n"); @@ -86,27 +87,50 @@ void nn_network_train( size_t samples = input_shape[0]; for (size_t l = 0; l < network_size; l++) { - outs[l] = calloc(samples * network[l].neurons, sizeof(double)); - zouts[l] = calloc(samples * network[l].neurons, sizeof(double)); + outs[l] = calloc(batch_size * network[l].neurons, sizeof(double)); + zouts[l] = calloc(batch_size * network[l].neurons, sizeof(double)); weights[l] = malloc(network[l].input_nodes * network[l].neurons * sizeof(double)); biases[l] = malloc(network[l].neurons * sizeof(double)); + if (!outs[l] || !zouts || !weights[l] || !biases) goto nn_network_train_error; + + memcpy(weights[l], network[l].weights, sizeof(double) * network[l].input_nodes * network[l].neurons); memcpy(biases[l], network[l].bias, sizeof(double) * network[l].neurons); } + + size_t batch_input_shape[2] = {batch_size, input_shape[1]}; + size_t batch_labels_shape[2] = {batch_size, labels_shape[1]}; + size_t n_batches = input_shape[0] / batch_size; + if (samples % batch_size) { + n_batches++; + } for (size_t epoch = 0; epoch < epochs; epoch++) { - nn_forward(outs, zouts, input, input_shape, network, network_size); - nn_backward( - weights, biases, - zouts, outs, - input, input_shape, - labels, labels_shape, - network, network_size, - cost.dfunc_out, alpha); - double *net_out = outs[network_size - 1]; - fprintf(stdout, "epoch: %zu \t loss: %6.6lf\n", - epoch, get_avg_loss(labels, net_out, labels_shape, cost.func)); + for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { + size_t index = batch_size * batch_idx; + + double *input_batch = input + index * input_shape[1]; + double *labels_batch = labels + index * labels_shape[1]; + + if (batch_idx == n_batches - 1 && samples % batch_size) { + batch_input_shape[0] = samples % batch_size; + batch_labels_shape[0] = samples % batch_size; + } + + nn_forward(outs, zouts, input_batch, batch_input_shape, network, network_size); + nn_backward( + weights, biases, + zouts, outs, + input_batch, batch_input_shape, + labels_batch, batch_labels_shape, + network, network_size, + cost.dfunc_out, alpha); + double *net_out = outs[network_size - 1]; + fprintf(stdout, "epoch: %g \t loss: %6.6lf\n", + epoch + (float)batch_idx / n_batches, + get_avg_loss(labels, net_out, batch_labels_shape, cost.func)); + } } for (size_t l = 0; l < network_size; l++) { @@ -417,7 +441,7 @@ void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols } for (size_t i = 0; i < cols; i++) { - bias[i] = (double)random_bias[i] / (double)INT64_MAX * 2; + bias[i] = (double)random_bias[i] / (double)INT64_MAX * 2; } free(random_weights); |