diff --git a/.gitignore b/.gitignore index 66c97be110009856c9648c45dac402b12daa0619..e4420b30c1c50c3d6cd4fe1943f860a3f05ffbc7 100644 --- a/.gitignore +++ b/.gitignore @@ -214,4 +214,5 @@ sympy-plots-for-*.tex/ *.upa *.upb -# pythontex \ No newline at end of file +# pythontex +/.luarc.json diff --git a/dev/rebuttal/support.pdf b/dev/rebuttal/support.pdf index 28566e51663c31caef6c495e8b2c13c3d078d88e..b6a6ba86c97bcfaea8e5405c50a1c4d88d0b1914 100644 Binary files a/dev/rebuttal/support.pdf and b/dev/rebuttal/support.pdf differ diff --git a/dev/rebuttal/support.qmd b/dev/rebuttal/support.qmd index 0658fdc26760c3e1d54f15e22cca2636363396a8..2153272941537d12b5d9687e823ce0da5619a5d6 100644 --- a/dev/rebuttal/support.qmd +++ b/dev/rebuttal/support.qmd @@ -2,11 +2,28 @@ format: pdf: fontsize: 10pt + geometry: + - top=10mm + - bottom=10mm + - left=10mm + - right=10mm --- +\pagenumbering{gobble} -{#fig-mnist-1 width="80%"} +::: {#fig-fmnist layout="[[-1,10,-1], [-1,10,-1], [-1,10,-1], [20,-1,20,-1,20]]"} -{#fig-mnist-2 width="80%"} +{#fig-mnist-1 width="50%"} -<!-- {#fig-mnist-3 width="75%"} --> \ No newline at end of file +{#fig-mnist-2 width="50%"} + +{#fig-mnist-3 width="50%"} + +{#fig-fmnist-dress width="5cm"} + +{#fig-fmnist-pullover width="5cm"} + +{#fig-fmnist-boot width="5cm"} + +Qualitative examples for MNIST and Fashion-MNIST. **MNIST** (@fig-mnist-1 to @fig-mnist-3): Top row: ECCCo. Bottom row: Wachter. The different underlying models across columns: (a) MLP, (b) Small Ensemble $n=5$, (c) Large Ensemble $n=50$, (d) LeNet-5, (e) JEM, (f) JEM Ensemble. **Fashion-MNIST** (@fig-fmnist-dress to @fig-fmnist-boot): The underlying classifier is a small ensemble of 5 MLPs with one hidden layer. Counterfactuals are generated by ECCCo. +::: \ No newline at end of file diff --git a/dev/rebuttal/www/fmnist_boot.png b/dev/rebuttal/www/fmnist_boot.png index 4ded1dcdfc2163482e8e09ef3479b839ac9df280..67b84044a1b051a62dbd3fcb5e125ef6ee6c051d 100644 Binary files a/dev/rebuttal/www/fmnist_boot.png and b/dev/rebuttal/www/fmnist_boot.png differ diff --git a/dev/rebuttal/www/fmnist_boot_lenet.png b/dev/rebuttal/www/fmnist_boot_lenet.png new file mode 100644 index 0000000000000000000000000000000000000000..86cece2b526fe6f4598ba78e1143d6cfa58f5295 Binary files /dev/null and b/dev/rebuttal/www/fmnist_boot_lenet.png differ diff --git a/dev/rebuttal/www/fmnist_dress.png b/dev/rebuttal/www/fmnist_dress.png index 0904b70aed7beaf52cce9c18c6e43193a53d1741..f8bf885e5443565d4592dec567465a879c61aed1 100644 Binary files a/dev/rebuttal/www/fmnist_dress.png and b/dev/rebuttal/www/fmnist_dress.png differ diff --git a/dev/rebuttal/www/fmnist_pullover.png b/dev/rebuttal/www/fmnist_pullover.png new file mode 100644 index 0000000000000000000000000000000000000000..1c5272b75cecb729e3dd5dedfbe72510e9c3770d Binary files /dev/null and b/dev/rebuttal/www/fmnist_pullover.png differ diff --git a/dev/rebuttal/www/fmnist_pullover_lenet.png b/dev/rebuttal/www/fmnist_pullover_lenet.png new file mode 100644 index 0000000000000000000000000000000000000000..d97e1bcedaf0dd17da9d9dfc27715a4c76a4ea5a Binary files /dev/null and b/dev/rebuttal/www/fmnist_pullover_lenet.png differ diff --git a/dev/rebuttal/www/mnist_0to3_28.png b/dev/rebuttal/www/mnist_0to3_28.png new file mode 100644 index 0000000000000000000000000000000000000000..2aeca96e4286d3c8b1699419547fcf9ab340e324 Binary files /dev/null and b/dev/rebuttal/www/mnist_0to3_28.png differ diff --git a/dev/rebuttal/www/mnist_0to3_29.png b/dev/rebuttal/www/mnist_0to3_29.png new file mode 100644 index 0000000000000000000000000000000000000000..eff2d038a2c0952c4a113725568cef5be3aea05a Binary files /dev/null and b/dev/rebuttal/www/mnist_0to3_29.png differ diff --git a/dev/rebuttal/www/mnist_0to3_30.png b/dev/rebuttal/www/mnist_0to3_30.png new file mode 100644 index 0000000000000000000000000000000000000000..68c3c3d5c7ec208fd257c82d375ade6709ab560f Binary files /dev/null and b/dev/rebuttal/www/mnist_0to3_30.png differ diff --git a/dev/rebuttal/www/mnist_1to4_25.png b/dev/rebuttal/www/mnist_1to4_25.png new file mode 100644 index 0000000000000000000000000000000000000000..0531c83c1e209f22033c5cf744caa87ece8d71b1 Binary files /dev/null and b/dev/rebuttal/www/mnist_1to4_25.png differ diff --git a/dev/rebuttal/www/mnist_1to4_26.png b/dev/rebuttal/www/mnist_1to4_26.png new file mode 100644 index 0000000000000000000000000000000000000000..2d48d5997b6059c5320f54c30bb0b7e059eb9ab6 Binary files /dev/null and b/dev/rebuttal/www/mnist_1to4_26.png differ diff --git a/dev/rebuttal/www/mnist_1to4_27.png b/dev/rebuttal/www/mnist_1to4_27.png new file mode 100644 index 0000000000000000000000000000000000000000..379b6c173e0948431e24fe36f630599e2e30150f Binary files /dev/null and b/dev/rebuttal/www/mnist_1to4_27.png differ diff --git a/dev/rebuttal/www/mnist_1to7_22.png b/dev/rebuttal/www/mnist_1to7_22.png new file mode 100644 index 0000000000000000000000000000000000000000..7b9a940569e42fd569d70f34b9e6ac1ee081a15d Binary files /dev/null and b/dev/rebuttal/www/mnist_1to7_22.png differ diff --git a/dev/rebuttal/www/mnist_1to7_23.png b/dev/rebuttal/www/mnist_1to7_23.png new file mode 100644 index 0000000000000000000000000000000000000000..b87e0ff2ec20eb0d8efc392434dce3d8c1b60f46 Binary files /dev/null and b/dev/rebuttal/www/mnist_1to7_23.png differ diff --git a/dev/rebuttal/www/mnist_1to7_24.png b/dev/rebuttal/www/mnist_1to7_24.png new file mode 100644 index 0000000000000000000000000000000000000000..149b2b5d208f9b55d425334a168bc7bb29012565 Binary files /dev/null and b/dev/rebuttal/www/mnist_1to7_24.png differ diff --git a/dev/rebuttal/www/mnist_2to3_13.png b/dev/rebuttal/www/mnist_2to3_13.png new file mode 100644 index 0000000000000000000000000000000000000000..b8a2e13d67274fd75bd7954289da7007f0fe582c Binary files /dev/null and b/dev/rebuttal/www/mnist_2to3_13.png differ diff --git a/dev/rebuttal/www/mnist_2to3_14.png b/dev/rebuttal/www/mnist_2to3_14.png new file mode 100644 index 0000000000000000000000000000000000000000..c24f72d9a15a41ceb1cd999e1d8bf884990b44d3 Binary files /dev/null and b/dev/rebuttal/www/mnist_2to3_14.png differ diff --git a/dev/rebuttal/www/mnist_2to3_15.png b/dev/rebuttal/www/mnist_2to3_15.png new file mode 100644 index 0000000000000000000000000000000000000000..28cc09447a152fb57ec4de68333f1d779f55d0b0 Binary files /dev/null and b/dev/rebuttal/www/mnist_2to3_15.png differ diff --git a/dev/rebuttal/www/mnist_4to1_10.png b/dev/rebuttal/www/mnist_4to1_10.png new file mode 100644 index 0000000000000000000000000000000000000000..dc2ce236891f8fdbe1a42987fe4cf54ed708a6fa Binary files /dev/null and b/dev/rebuttal/www/mnist_4to1_10.png differ diff --git a/dev/rebuttal/www/mnist_4to1_11.png b/dev/rebuttal/www/mnist_4to1_11.png index e668a3fa2150240e71dfecc2b02bf8f90dae7eb7..888d65979cdb1c0618598910a770f740acb96dce 100644 Binary files a/dev/rebuttal/www/mnist_4to1_11.png and b/dev/rebuttal/www/mnist_4to1_11.png differ diff --git a/dev/rebuttal/www/mnist_4to1_12.png b/dev/rebuttal/www/mnist_4to1_12.png index effb697d4b1875ef27e31ae3525960f3fec95914..bebc47574dd1ee9d1f731bd48d1a4fa0ec8801cf 100644 Binary files a/dev/rebuttal/www/mnist_4to1_12.png and b/dev/rebuttal/www/mnist_4to1_12.png differ diff --git a/dev/rebuttal/www/mnist_5to8_16.png b/dev/rebuttal/www/mnist_5to8_16.png new file mode 100644 index 0000000000000000000000000000000000000000..3b40f09f54745ae24f3cbb86f6c3b87183e2ddc3 Binary files /dev/null and b/dev/rebuttal/www/mnist_5to8_16.png differ diff --git a/dev/rebuttal/www/mnist_5to8_17.png b/dev/rebuttal/www/mnist_5to8_17.png new file mode 100644 index 0000000000000000000000000000000000000000..03c04caa13550a531848fccc16d56aa67039a34b Binary files /dev/null and b/dev/rebuttal/www/mnist_5to8_17.png differ diff --git a/dev/rebuttal/www/mnist_5to8_18.png b/dev/rebuttal/www/mnist_5to8_18.png new file mode 100644 index 0000000000000000000000000000000000000000..86a38ee5cfa732d92d4449fe6d4bc93203213610 Binary files /dev/null and b/dev/rebuttal/www/mnist_5to8_18.png differ diff --git a/dev/rebuttal/www/mnist_6to0_7.png b/dev/rebuttal/www/mnist_6to0_7.png index 223ff3f798b56b02db1ab5e9bec8b1ef3a93526e..cd3cee5812c087dede53298a913380b067891c11 100644 Binary files a/dev/rebuttal/www/mnist_6to0_7.png and b/dev/rebuttal/www/mnist_6to0_7.png differ diff --git a/dev/rebuttal/www/mnist_6to0_8.png b/dev/rebuttal/www/mnist_6to0_8.png index 3b03fae8b353065f957b4f06215a541e200bb69f..5dad24bc7b60fdd48ab0dbe4eff2ad1bceb194ec 100644 Binary files a/dev/rebuttal/www/mnist_6to0_8.png and b/dev/rebuttal/www/mnist_6to0_8.png differ diff --git a/dev/rebuttal/www/mnist_6to0_9.png b/dev/rebuttal/www/mnist_6to0_9.png index e392bb9c1d92f727adc159bcd7fe9c2ae913da1f..a01cb22373a28399bed4f28bfde80ed506aaa43e 100644 Binary files a/dev/rebuttal/www/mnist_6to0_9.png and b/dev/rebuttal/www/mnist_6to0_9.png differ diff --git a/dev/rebuttal/www/mnist_6to5_19.png b/dev/rebuttal/www/mnist_6to5_19.png new file mode 100644 index 0000000000000000000000000000000000000000..43b98cf6c66451528cc32d646cc831a533aeb3b1 Binary files /dev/null and b/dev/rebuttal/www/mnist_6to5_19.png differ diff --git a/dev/rebuttal/www/mnist_6to5_20.png b/dev/rebuttal/www/mnist_6to5_20.png new file mode 100644 index 0000000000000000000000000000000000000000..816e5088445e3ca2437f781be7e4030de233cbe5 Binary files /dev/null and b/dev/rebuttal/www/mnist_6to5_20.png differ diff --git a/dev/rebuttal/www/mnist_6to5_21.png b/dev/rebuttal/www/mnist_6to5_21.png new file mode 100644 index 0000000000000000000000000000000000000000..9e435fac593b024c281ce1c9c652628bd31f92c3 Binary files /dev/null and b/dev/rebuttal/www/mnist_6to5_21.png differ diff --git a/dev/rebuttal/www/mnist_7to2_4.png b/dev/rebuttal/www/mnist_7to2_4.png new file mode 100644 index 0000000000000000000000000000000000000000..7be60ab7d8bc331821b29140d35d374e87fe98c5 Binary files /dev/null and b/dev/rebuttal/www/mnist_7to2_4.png differ diff --git a/dev/rebuttal/www/mnist_7to2_5.png b/dev/rebuttal/www/mnist_7to2_5.png new file mode 100644 index 0000000000000000000000000000000000000000..2521e2a09e7260e1e4a713b73694ac86da835cd7 Binary files /dev/null and b/dev/rebuttal/www/mnist_7to2_5.png differ diff --git a/dev/rebuttal/www/mnist_7to2_6.png b/dev/rebuttal/www/mnist_7to2_6.png new file mode 100644 index 0000000000000000000000000000000000000000..3c3a27515cd8762c920d85cda0e32a367ca5a4ad Binary files /dev/null and b/dev/rebuttal/www/mnist_7to2_6.png differ diff --git a/dev/rebuttal/www/mnist_9to7_1.png b/dev/rebuttal/www/mnist_9to7_1.png index 7bfde3bc43751ac0495585be65124e77275618d7..ca1b1b890224991b5e94e0bd65834613934b7a35 100644 Binary files a/dev/rebuttal/www/mnist_9to7_1.png and b/dev/rebuttal/www/mnist_9to7_1.png differ diff --git a/dev/rebuttal/www/mnist_9to7_2.png b/dev/rebuttal/www/mnist_9to7_2.png index aa1b70df6347b19a5eab8e0258f2fdc62719b13f..c199846bf9bbe909cf0d8dfdb868847d5be34b5c 100644 Binary files a/dev/rebuttal/www/mnist_9to7_2.png and b/dev/rebuttal/www/mnist_9to7_2.png differ diff --git a/dev/rebuttal/www/mnist_9to7_3.png b/dev/rebuttal/www/mnist_9to7_3.png index 0ebeccfc30999cbf95238ea17f2c4a7e0c3b3414..505b8fc345892dcd4e92e82ce8e9ee0ccdab40a5 100644 Binary files a/dev/rebuttal/www/mnist_9to7_3.png and b/dev/rebuttal/www/mnist_9to7_3.png differ diff --git a/notebooks/fmnist.qmd b/notebooks/fmnist.qmd index bf73c90d1d3a8ce68f8807ca2b18aca394bea493..d10e88b46754b293ea020e8fe84c7fc0cc7b4fc9 100644 --- a/notebooks/fmnist.qmd +++ b/notebooks/fmnist.qmd @@ -177,8 +177,21 @@ model_performance ``` ```{julia} +label_names = [ + "T-shirt/top", + "Trouser", + "Pullover", + "Dress", + "Coat", + "Sandal", + "Shirt", + "Sneaker", + "Bag", + "Boot", +] + function plot_fmnist( - x::Union{AbstractArray, Int}=x_factual, target::Int=target; + x::Int=3, targets::Vector{Int}=[1]; generator = ECCCoGenerator(), rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, T::Int = 100, @@ -186,16 +199,15 @@ function plot_fmnist( model=model_dict["Ensemble"], img_height::Int = img_height, test_data::Bool = false, + init::Symbol = :identity, + noise::Float32 = 0.0f0, kwrgs..., ) # Setup: Random.seed!(rng) - if x isa Int - x_fact = counterfactual_data.X[:,rand(findall(labels.==x))][:,:] - else - x_fact = x - end + x_fact = counterfactual_data.X[:,rand(findall(labels.==x))][:,:] + x_fact = pre_process(x_fact, noise=noise) if test_data data = load_mnist_test() @@ -203,30 +215,36 @@ function plot_fmnist( data = counterfactual_data end - ce = generate_counterfactual( - x_fact, target, data, model, generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, - ) - # Plot: p1 = Plots.plot( convert2image(MNIST, reshape(x_fact,28,28)), axis=([], false), size=(img_height, img_height), - title="Factual" + title="$(label_names[x+1]) (factual)", + dpi=300 ) - _x = CounterfactualExplanations.counterfactual(ce) - p2 = Plots.plot( - convert2image(MNIST, reshape(_x,28,28)), - axis=([], false), - size=(img_height, img_height), - title="Counterfactual" - ) - - plt = Plots.plot(p1, p2; size=(img_height*2,img_height), layout=(1,2), kwrgs...) + plts = [p1] + + for t in targets + ce = generate_counterfactual( + x_fact, t, data, model, generator; + decision_threshold=γ, max_iter=T, + initialization=init, + converge_when=:generator_conditions, + ) + _x = CounterfactualExplanations.counterfactual(ce) + plt = Plots.plot( + convert2image(MNIST, reshape(_x,28,28)), + axis=([], false), + size=(img_height, img_height), + title="$(label_names[t+1])", + dpi=300 + ) + plts = [plts..., plt] + end + + plt = Plots.plot(plts..., size=(img_height*length(plts),img_height), layout=(1,length(plts)), kwrgs...) return plt end @@ -235,7 +253,27 @@ end ```{julia} -rng = rand(1:10000) -factual = 3 -plt = plot_fmnist(3, 1, rng=rng) -``` \ No newline at end of file +rng = 42 +gen = ECCCoGenerator( + λ=[0.1,0.30,0.30], + nsamples=25, + nmin=10, +) +plts = [] + +plt = plot_fmnist(3, [1]; rng=rng, generator=gen, img_height=200) +display(plt) +push!(plts, plt) +savefig(plt, "dev/rebuttal/www/fmnist_dress.png") + +plt = plot_fmnist(9, [5]; rng=rng, generator=gen, img_height=200) +display(plt) +push!(plts, plt) +savefig(plt, "dev/rebuttal/www/fmnist_boot.png") + +plt = plot_fmnist(2, [4]; rng=rng, generator=gen, img_height=200) +display(plt) +push!(plts, plt) +savefig(plt, "dev/rebuttal/www/fmnist_pullover.png") +``` + diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 578a2d18753aff439148f68fe0ad363b71a5b6fa..2d88f306a27ef313f91998472fa8b4fa17461300 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -613,9 +613,9 @@ display(plt_additional_models) ``` ```{julia} -Random.seed!(2023) +Random.seed!(123) -λ = [0.1,0.215,0.215] +λ = [0.1,0.22,0.22] wachter = WachterGenerator( λ=λ[1], @@ -624,13 +624,18 @@ wachter = WachterGenerator( combos = [ (9,7), + (7,2), (6,0), (4,1), (2,3), (5,8), + (6,5), + (1,7), + (1,4), + (0,3) ] -n_each = 2 +n_each = 3 combos = reduce(vcat, [fill(c, n_each) for c in combos]) plts_eccco = [] @@ -684,7 +689,7 @@ for (i, (factual, target)) in enumerate(combos) p1 = plts_eccco[i] p2 = plts_wachter[i] plt = Plots.plot( - p1, p2, layout=(2,1), plot_title="Factual: $factual, Target: $target", + p1, p2, layout=(2,1), size = (1200, 400), ) display(plt)