From bab6d665d98668f67686a3085d7de46204f67366 Mon Sep 17 00:00:00 2001 From: jvech Date: Sun, 3 Sep 2023 16:41:17 -0500 Subject: fix: write network parameter fixed --- .gitignore | 1 + Makefile | 7 ++++--- src/nn.c | 3 ++- utils/settings.cfg | 18 +++++++----------- 4 files changed, 14 insertions(+), 15 deletions(-) diff --git a/.gitignore b/.gitignore index bdbc6b0..e18ab1f 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ doc/* !doc/*.tex *.gdb *.gpi +*.bin diff --git a/Makefile b/Makefile index 7cbbd9c..457dc62 100644 --- a/Makefile +++ b/Makefile @@ -23,11 +23,12 @@ build: $(OBJS) run: build @./${BIN} train data/sample_data.json | tee data/train_history.txt - @./${BIN} predict data/sample_data.json | jq -r '.[] | [values[] as $$val | $$val] | @tsv' > ./data/net_data.tsv - @gnuplot -p utils/plot.gpi + @./${BIN} predict data/sample_data.json | jq -r '.[] | [values[] as $$val | $$val] | @tsv' > data/net_data.tsv + @jq -r '.[] | [values[] as $$val | $$val] | @tsv' data/sample_data.json > data/sample_data.tsv + @gnuplot utils/plot.gpi debug: build - gdb -x utils/commands.gdb --tui --args ${BIN} train -a 230 data/sample_data.json -e 150 + gdb -x utils/commands.gdb --tui --args ${BIN} train data/sample_data.json gdb -x utils/commands.gdb --tui --args ${BIN} predict data/sample_data.json clean: diff --git a/src/nn.c b/src/nn.c index c7b8ac3..ce033e2 100644 --- a/src/nn.c +++ b/src/nn.c @@ -306,6 +306,7 @@ void nn_network_read_weights(char *filepath, Layer *network, size_t network_size if (ret != shape[1]) goto nn_network_read_weights_error; } + fclose(fp); return; nn_network_read_weights_error: @@ -331,7 +332,7 @@ void nn_network_write_weights(char *filepath, Layer *network, size_t network_siz ret = fwrite(network[i].weights, sizeof(double), size, fp); if (ret != size) goto nn_network_write_weights_error; - ret = fwrite(network[i].weights, sizeof(double), network[i].neurons, fp); + ret = fwrite(network[i].bias, sizeof(double), network[i].neurons, fp); if (ret != network[i].neurons) goto nn_network_write_weights_error; } fclose(fp); diff --git a/utils/settings.cfg b/utils/settings.cfg index 00bf040..eee85fc 100644 --- a/utils/settings.cfg +++ b/utils/settings.cfg @@ -1,20 +1,16 @@ [net] loss = square ; options (square) -epochs = 40 ; comment -alpha = 1e-4 +epochs = 200 ; comment +alpha = 1e-2 weights_path = utils/weights.bin -inputs = x, y -labels = z +inputs = x +labels = y ; activation options (relu, sigmoid, softplus, leaky_relu) [layer] -neurons=3 -activation=relu - -[layer] -neurons=4 -activation=relu +neurons=20 +activation=sigmoid [outlayer] -activation = sigmoid +activation = sigmoid -- cgit v1.2.3-70-g09d2