From 7300f3b643a6b2fcc66d4f77c5097d263dabc554 Mon Sep 17 00:00:00 2001 From: jvech Date: Tue, 16 Jul 2024 07:11:51 -0500 Subject: add: tests scripts implemented --- tests/plots.gpi | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/plots.gpi (limited to 'tests/plots.gpi') diff --git a/tests/plots.gpi b/tests/plots.gpi new file mode 100644 index 0000000..85fbf98 --- /dev/null +++ b/tests/plots.gpi @@ -0,0 +1,37 @@ +#!/usr/bin/gnuplot +set term pngcairo size 1080,720 +set output 'tests/network_accuracy.png' +set multiplot layout 2, 2 +set grid + +json2tsv = "jq -r '.[] | [.[]] | @tsv' %s" +train_cmd = "<./ml train %s -c %s" +predict_cmd = "<./ml predict %s -c %s | ".sprintf(json2tsv, "-") + +# -- +data_gauss1d = "data/gauss1d.json" +data_xor = "data/xor.json" + +# -- +arch_gauss1d = "tests/architectures/gauss1d.cfg" +arch_xor = "tests/architectures/xor.cfg" + + +set ylabel arch_gauss1d +set logscale x +plot sprintf(train_cmd, data_gauss1d, arch_gauss1d) u 2:4 with lines title 'loss' +unset logscale +unset ylabel + +plot sprintf(predict_cmd, data_gauss1d, arch_gauss1d) with lines title 'network',\ + "<".sprintf(json2tsv, data_gauss1d) with lines title 'original' + +set ylabel arch_xor +set logscale x +plot sprintf(train_cmd, data_xor, arch_xor) u 2:4 with lines title 'loss' +unset logscale +unset ylabel + +set table "/dev/stdout" +plot "<".sprintf(json2tsv, data_xor) using 1:2:3 with table,\ + sprintf(predict_cmd, data_xor, arch_xor) using 3 with table -- cgit v1.2.3-70-g09d2