Diffstat (limited to 'src')
-rw-r--r--   src/main.c    2
-rw-r--r--   src/nn.c     54
-rw-r--r--   src/nn.h      3
-rw-r--r--   src/util.c   13
-rw-r--r--   src/util.h    1
5 files changed, 54 insertions, 19 deletions
diff --git a/src/main.c b/src/main.c
--- a/src/main.c
+++ b/src/main.c
@@ -91,6 +91,7 @@ int main(int argc, char *argv[]) {
     char default_config_path[512], *env_config_path;
     struct Configs ml_configs = {
         .epochs = 100,
+        .batch_size = 32,
         .alpha = 1e-5,
         .config_filepath = "",
         .network_size = 0,
@@ -138,6 +139,7 @@ int main(int argc, char *argv[]) {
             y.data, y.shape,
             load_loss(ml_configs),
             ml_configs.epochs,
+            ml_configs.batch_size,
             ml_configs.alpha);
     nn_network_write_weights(ml_configs.weights_filepath, network, ml_configs.network_size);
     fprintf(stderr, "weights saved on '%s'\n", ml_configs.weights_filepath);
diff --git a/src/nn.c b/src/nn.c
--- a/src/nn.c
+++ b/src/nn.c
@@ -71,7 +71,8 @@ void nn_network_train(
         Layer network[], size_t network_size,
         double *input, size_t input_shape[2],
         double *labels, size_t labels_shape[2],
-        struct Cost cost, size_t epochs, double alpha)
+        struct Cost cost, size_t epochs,
+        size_t batch_size, double alpha)
 {
     assert(input_shape[0] == labels_shape[0]
            && "label samples don't correspond with input samples\n");
@@ -86,27 +87,50 @@ void nn_network_train(
     size_t samples = input_shape[0];
     for (size_t l = 0; l < network_size; l++) {
-        outs[l] = calloc(samples * network[l].neurons, sizeof(double));
-        zouts[l] = calloc(samples * network[l].neurons, sizeof(double));
+        outs[l] = calloc(batch_size * network[l].neurons, sizeof(double));
+        zouts[l] = calloc(batch_size * network[l].neurons, sizeof(double));
         weights[l] = malloc(network[l].input_nodes * network[l].neurons * sizeof(double));
         biases[l] = malloc(network[l].neurons * sizeof(double));
+        if (!outs[l] || !zouts || !weights[l] || !biases) goto nn_network_train_error;
+
         memcpy(weights[l], network[l].weights, sizeof(double) * network[l].input_nodes * network[l].neurons);
         memcpy(biases[l], network[l].bias, sizeof(double) * network[l].neurons);
     }
+
+    size_t batch_input_shape[2] = {batch_size, input_shape[1]};
+    size_t batch_labels_shape[2] = {batch_size, labels_shape[1]};
+    size_t n_batches = input_shape[0] / batch_size;
+    if (samples % batch_size) {
+        n_batches++;
+    }
 
     for (size_t epoch = 0; epoch < epochs; epoch++) {
-        nn_forward(outs, zouts, input, input_shape, network, network_size);
-        nn_backward(
-                weights, biases,
-                zouts, outs,
-                input, input_shape,
-                labels, labels_shape,
-                network, network_size,
-                cost.dfunc_out, alpha);
-        double *net_out = outs[network_size - 1];
-        fprintf(stdout, "epoch: %zu \t loss: %6.6lf\n",
-                epoch, get_avg_loss(labels, net_out, labels_shape, cost.func));
+        for (size_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
+            size_t index = batch_size * batch_idx;
+
+            double *input_batch = input + index * input_shape[1];
+            double *labels_batch = labels + index * labels_shape[1];
+
+            if (batch_idx == n_batches - 1 && samples % batch_size) {
+                batch_input_shape[0] = samples % batch_size;
+                batch_labels_shape[0] = samples % batch_size;
+            }
+
+            nn_forward(outs, zouts, input_batch, batch_input_shape, network, network_size);
+            nn_backward(
+                    weights, biases,
+                    zouts, outs,
+                    input_batch, batch_input_shape,
+                    labels_batch, batch_labels_shape,
+                    network, network_size,
+                    cost.dfunc_out, alpha);
+            double *net_out = outs[network_size - 1];
+            fprintf(stdout, "epoch: %g \t loss: %6.6lf\n",
+                    epoch + (float)batch_idx / n_batches,
+                    get_avg_loss(labels, net_out, batch_labels_shape, cost.func));
+        }
     }
 
     for (size_t l = 0; l < network_size; l++) {
@@ -417,7 +441,7 @@ void fill_random_weights(double *weights, double *bias, size_t rows, size_t cols)
     }
 
     for (size_t i = 0; i < cols; i++) {
-        bias[i] = (double)random_bias[i] / (double)INT64_MAX * 2;
+        bias[i] = (double)random_bias[i] / (double)INT64_MAX * 2;
     }
 
     free(random_weights);
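The hunk above replaces one full-dataset pass per epoch with a loop over fixed-size batches. The slicing arithmetic it relies on can be sketched as a standalone C program (names and sizes below are illustrative assumptions, not taken from the project): samples are stored row-major, so batch b starts at offset b * batch_size * features, and the final batch shrinks to samples % batch_size when the sample count is not divisible by the batch size.

    /* Standalone sketch of the batch-slicing arithmetic used by the new
     * training loop; the data layout and names are illustrative assumptions. */
    #include <stdio.h>
    #include <stddef.h>

    int main(void)
    {
        size_t samples = 10, features = 3, batch_size = 4;
        double data[10 * 3];

        for (size_t i = 0; i < samples * features; i++) data[i] = (double)i;

        size_t n_batches = samples / batch_size;
        if (samples % batch_size) n_batches++;      /* partial final batch */

        for (size_t b = 0; b < n_batches; b++) {
            size_t rows = batch_size;
            if (b == n_batches - 1 && samples % batch_size)
                rows = samples % batch_size;        /* shrink the last batch */

            double *batch = data + b * batch_size * features;  /* first row of batch b */
            printf("batch %zu: rows=%zu, first value=%g\n", b, rows, batch[0]);
        }
        return 0;
    }
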
diff --git a/src/nn.h b/src/nn.h
--- a/src/nn.h
+++ b/src/nn.h
@@ -52,7 +52,8 @@ void nn_network_train(
         Layer network[], size_t network_size,
         double *input, size_t input_shape[2],
         double *labels, size_t labels_shape[2],
-        struct Cost cost, size_t epochs, double alpha);
+        struct Cost cost, size_t epochs,
+        size_t batch_size, double alpha);
 
 void nn_layer_map_activation(
         double (*activation)(double),
diff --git a/src/util.c b/src/util.c
--- a/src/util.c
+++ b/src/util.c
@@ -97,11 +97,12 @@ void usage(int exit_code)
            "Options:\n"
            "  -h, --help           Show this message\n"
            "  -f, --format=FORMAT  Define input or output FILE format if needed\n"
+           "  -O, --only-out       Don't show input fields (only works with predict)\n"
            "  -a, --alpha=ALPHA    Learning rate (only works with train)\n"
+           "  -b, --batch=INT      Select batch size [default: 32] (only works with train)\n"
+           "  -c, --config=FILE    Configuration filepath [default=~/.config/ml/ml.cfg]\n"
            "  -e, --epochs=EPOCHS  Epochs to train the model (only works with train)\n"
            "  -o, --output=FILE    Output file (only works with predict)\n"
-           "  -O, --only-out       Don't show input fields (only works with predict)\n"
-           "  -c, --config=FILE    Configuration filepath [default=~/.config/ml/ml.cfg]\n"
            "  -p, --precision=INT  Decimals output precision (only works with predict)\n"
            "                       [default=auto]\n"
            "\n"
@@ -117,6 +118,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
         {"version",   no_argument,       0, 'v'},
         {"format",    required_argument, 0, 'f'},
         {"epochs",    required_argument, 0, 'e'},
+        {"batch",     required_argument, 0, 'b'},
         {"alpha",     required_argument, 0, 'a'},
         {"output",    required_argument, 0, 'o'},
         {"config",    required_argument, 0, 'c'},
@@ -127,7 +129,7 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
     int c;
     while (1) {
-        c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:", long_opts, NULL);
+        c = getopt_long(argc, argv, "hvOc:e:a:o:i:f:p:b:", long_opts, NULL);
 
         if (c == -1) {
             break;
@@ -154,6 +156,10 @@ void util_load_cli(struct Configs *ml, int argc, char *argv[])
         case 'p':
             ml->decimal_precision = (!strcmp("auto", optarg))? -1: (int)atoi(optarg);
             break;
+        case 'b':
+            if (atoi(optarg) <= 0) die("util_load_cli() Error: batch size must be greater than 0");
+            ml->batch_size = (size_t)atol(optarg);
+            break;
         case 'h':
             usage(0);
             break;
@@ -316,6 +322,7 @@ void load_net_cfgs(struct Configs *cfg, char *key, char *value, char *strtok_ptr)
     if (!strcmp(key, "weights_path"))   cfg->weights_filepath = e_strdup(value);
     else if (!strcmp(key, "loss"))      cfg->loss = e_strdup(value);
     else if (!strcmp(key, "epochs"))    cfg->epochs = (size_t)atol(value);
+    else if (!strcmp(key, "batch"))     cfg->batch_size = (size_t)atol(value);
     else if (!strcmp(key, "alpha"))     cfg->alpha = (double)atof(value);
     else if (!strcmp(key, "inputs"))    cfg->input_keys = config_read_values(&(cfg->n_input_keys), value, &strtok_ptr);
     else if (!strcmp(key, "labels"))    cfg->label_keys = config_read_values(&(cfg->n_label_keys), value, &strtok_ptr);
diff --git a/src/util.h b/src/util.h
--- a/src/util.h
+++ b/src/util.h
@@ -7,6 +7,7 @@ struct Configs {
     /* net cfgs */
     size_t epochs;
+    size_t batch_size;
     double alpha;
     char *loss;
    char **input_keys, **label_keys;
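For context, the new -b/--batch flag follows the same getopt_long pattern as the existing options. Below is a minimal standalone sketch of that wiring (a simplified stand-in, not the project's util_load_cli(); the other flags and the die() helper are omitted):

    /* Minimal sketch of parsing "-b/--batch" with getopt_long, mirroring the
     * option-table entry, the "b:" short option and the > 0 validation added
     * in src/util.c.  Illustrative assumption, not the project's code. */
    #include <getopt.h>
    #include <stdio.h>
    #include <stdlib.h>

    int main(int argc, char *argv[])
    {
        size_t batch_size = 32;                 /* default used by the patch */
        static struct option long_opts[] = {
            {"batch", required_argument, 0, 'b'},
            {0, 0, 0, 0},
        };

        int c;
        while ((c = getopt_long(argc, argv, "b:", long_opts, NULL)) != -1) {
            switch (c) {
            case 'b':
                if (atoi(optarg) <= 0) {
                    fprintf(stderr, "batch size must be greater than 0\n");
                    return 1;
                }
                batch_size = (size_t)atol(optarg);
                break;
            default:
                return 1;
            }
        }

        printf("batch_size = %zu\n", batch_size);
        return 0;
    }

The same value can also be supplied through the configuration file via the new "batch" key handled in load_net_cfgs(), falling back to the default of 32 set in main().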