From ebd66e65bf18574fa8905d7b0ae3fbb85bfc9e06 Mon Sep 17 00:00:00 2001 From: jvech Date: Tue, 6 Aug 2024 14:29:42 -0500 Subject: add: file parsing improved Things implemented: * json_read() must die if the key does not exist or the value type is wrong. * on predict command input should be shown exactly the same * float precision CLI option should be added. --- tests/architectures/gauss2d.cfg | 19 +++++++++++++++++++ tests/plots.gpi | 27 ++++++++++++++++----------- 2 files changed, 35 insertions(+), 11 deletions(-) create mode 100644 tests/architectures/gauss2d.cfg (limited to 'tests') diff --git a/tests/architectures/gauss2d.cfg b/tests/architectures/gauss2d.cfg new file mode 100644 index 0000000..d9236ad --- /dev/null +++ b/tests/architectures/gauss2d.cfg @@ -0,0 +1,19 @@ +[net] +loss = square ; options (square) +epochs = 1000 ; comment +alpha = 2e-4 +weights_path = data/gauss2d.bin +inputs = x,y +labels = z + +; activation options (relu, sigmoid, softplus, leaky_relu) + +[layer] +neurons=20 +activation=sigmoid +[layer] +neurons=10 +activation=relu + +[outlayer] +activation = sigmoid diff --git a/tests/plots.gpi b/tests/plots.gpi index 4fd11b0..2101520 100644 --- a/tests/plots.gpi +++ b/tests/plots.gpi @@ -1,7 +1,7 @@ #!/usr/bin/gnuplot -set term pngcairo size 1080,720 +set term pngcairo size 1080,360*3 set output 'tests/network_accuracy.png' -set multiplot layout 2, 2 +set multiplot layout 3, 2 set grid json2tsv = "jq -r '.[] | [.[]] | @tsv' %s" @@ -12,11 +12,13 @@ predict_cmd = "<./ml predict %s -c %s | ".sprintf(json2tsv, "-") data_gauss1d = "data/gauss1d.json" data_xor = "data/xor.json" data_sine = "data/sine.json" +data_gauss2d = "data/gauss2d.json" # -- arch_gauss1d = "tests/architectures/gauss1d.cfg" arch_xor = "tests/architectures/xor.cfg" arch_sine = "tests/architectures/sine.cfg" +arch_gauss2d = "tests/architectures/gauss2d.cfg" set ylabel arch_gauss1d @@ -28,15 +30,6 @@ unset ylabel plot sprintf(predict_cmd, data_gauss1d, arch_gauss1d) with lines title 'network',\ "<".sprintf(json2tsv, data_gauss1d) with lines title 'original' -#set ylabel arch_xor -#set logscale x -#plot sprintf(train_cmd, data_xor, arch_xor) u 2:4 with lines title 'loss' -#unset logscale -#unset ylabel -# -#set table "/dev/stdout" -#plot "<".sprintf(json2tsv, data_xor) using 1:2:3 with table,\ -# sprintf(predict_cmd, data_xor, arch_xor) using 3 with table set ylabel arch_sine set logscale x @@ -46,3 +39,15 @@ unset ylabel plot sprintf(predict_cmd, data_sine, arch_sine) with lines title 'network',\ "<".sprintf(json2tsv, data_sine) with lines title 'original' + + +set ylabel arch_gauss2d +set logscale x +plot sprintf(train_cmd, data_gauss2d, arch_gauss2d) u 2:4 with lines title 'loss' +unset logscale +unset ylabel + +set view 45,30 +splot "<".sprintf(json2tsv, data_gauss2d) using 1:2:3 with lines title 'network',\ + sprintf(predict_cmd, data_gauss2d, arch_gauss2d) with lines title 'original' +unset multiplot -- cgit v1.2.3-70-g09d2