Skip to content
Snippets Groups Projects
Commit a6d09097 authored by Pat Alt's avatar Pat Alt
Browse files

done

parent ec9dbfb6
No related branches found
No related tags found
1 merge request!4336 rebuttal
Showing
with 74 additions and 31 deletions
dev/rebuttal/www/mnist_4to1_10.png

26.5 KiB

dev/rebuttal/www/mnist_4to1_11.png

29.6 KiB | W: | H:

dev/rebuttal/www/mnist_4to1_11.png

26.1 KiB | W: | H:

dev/rebuttal/www/mnist_4to1_11.png
dev/rebuttal/www/mnist_4to1_11.png
dev/rebuttal/www/mnist_4to1_11.png
dev/rebuttal/www/mnist_4to1_11.png
  • 2-up
  • Swipe
  • Onion skin
dev/rebuttal/www/mnist_4to1_12.png

29.7 KiB | W: | H:

dev/rebuttal/www/mnist_4to1_12.png

24.6 KiB | W: | H:

dev/rebuttal/www/mnist_4to1_12.png
dev/rebuttal/www/mnist_4to1_12.png
dev/rebuttal/www/mnist_4to1_12.png
dev/rebuttal/www/mnist_4to1_12.png
  • 2-up
  • Swipe
  • Onion skin
dev/rebuttal/www/mnist_5to8_16.png

25.3 KiB

dev/rebuttal/www/mnist_5to8_17.png

25.9 KiB

dev/rebuttal/www/mnist_5to8_18.png

26.6 KiB

dev/rebuttal/www/mnist_6to0_7.png

33.3 KiB | W: | H:

dev/rebuttal/www/mnist_6to0_7.png

27 KiB | W: | H:

dev/rebuttal/www/mnist_6to0_7.png
dev/rebuttal/www/mnist_6to0_7.png
dev/rebuttal/www/mnist_6to0_7.png
dev/rebuttal/www/mnist_6to0_7.png
  • 2-up
  • Swipe
  • Onion skin
dev/rebuttal/www/mnist_6to0_8.png

32 KiB | W: | H:

dev/rebuttal/www/mnist_6to0_8.png

26.2 KiB | W: | H:

dev/rebuttal/www/mnist_6to0_8.png
dev/rebuttal/www/mnist_6to0_8.png
dev/rebuttal/www/mnist_6to0_8.png
dev/rebuttal/www/mnist_6to0_8.png
  • 2-up
  • Swipe
  • Onion skin
dev/rebuttal/www/mnist_6to0_9.png

31.7 KiB | W: | H:

dev/rebuttal/www/mnist_6to0_9.png

25.4 KiB | W: | H:

dev/rebuttal/www/mnist_6to0_9.png
dev/rebuttal/www/mnist_6to0_9.png
dev/rebuttal/www/mnist_6to0_9.png
dev/rebuttal/www/mnist_6to0_9.png
  • 2-up
  • Swipe
  • Onion skin
dev/rebuttal/www/mnist_6to5_19.png

26.7 KiB

dev/rebuttal/www/mnist_6to5_20.png

27.7 KiB

dev/rebuttal/www/mnist_6to5_21.png

26.7 KiB

dev/rebuttal/www/mnist_7to2_4.png

26.9 KiB

dev/rebuttal/www/mnist_7to2_5.png

26.7 KiB

dev/rebuttal/www/mnist_7to2_6.png

25.8 KiB

dev/rebuttal/www/mnist_9to7_1.png

30 KiB | W: | H:

dev/rebuttal/www/mnist_9to7_1.png

26.7 KiB | W: | H:

dev/rebuttal/www/mnist_9to7_1.png
dev/rebuttal/www/mnist_9to7_1.png
dev/rebuttal/www/mnist_9to7_1.png
dev/rebuttal/www/mnist_9to7_1.png
  • 2-up
  • Swipe
  • Onion skin
dev/rebuttal/www/mnist_9to7_2.png

29.3 KiB | W: | H:

dev/rebuttal/www/mnist_9to7_2.png

26.9 KiB | W: | H:

dev/rebuttal/www/mnist_9to7_2.png
dev/rebuttal/www/mnist_9to7_2.png
dev/rebuttal/www/mnist_9to7_2.png
dev/rebuttal/www/mnist_9to7_2.png
  • 2-up
  • Swipe
  • Onion skin
dev/rebuttal/www/mnist_9to7_3.png

31.7 KiB | W: | H:

dev/rebuttal/www/mnist_9to7_3.png

25.9 KiB | W: | H:

dev/rebuttal/www/mnist_9to7_3.png
dev/rebuttal/www/mnist_9to7_3.png
dev/rebuttal/www/mnist_9to7_3.png
dev/rebuttal/www/mnist_9to7_3.png
  • 2-up
  • Swipe
  • Onion skin
......@@ -177,8 +177,21 @@ model_performance
```
```{julia}
label_names = [
"T-shirt/top",
"Trouser",
"Pullover",
"Dress",
"Coat",
"Sandal",
"Shirt",
"Sneaker",
"Bag",
"Boot",
]
function plot_fmnist(
x::Union{AbstractArray, Int}=x_factual, target::Int=target;
x::Int=3, targets::Vector{Int}=[1];
generator = ECCCoGenerator(),
rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
T::Int = 100,
......@@ -186,16 +199,15 @@ function plot_fmnist(
model=model_dict["Ensemble"],
img_height::Int = img_height,
test_data::Bool = false,
init::Symbol = :identity,
noise::Float32 = 0.0f0,
kwrgs...,
)
# Setup:
Random.seed!(rng)
if x isa Int
x_fact = counterfactual_data.X[:,rand(findall(labels.==x))][:,:]
else
x_fact = x
end
x_fact = counterfactual_data.X[:,rand(findall(labels.==x))][:,:]
x_fact = pre_process(x_fact, noise=noise)
if test_data
data = load_mnist_test()
......@@ -203,30 +215,36 @@ function plot_fmnist(
data = counterfactual_data
end
ce = generate_counterfactual(
x_fact, target, data, model, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x_fact,28,28)),
axis=([], false),
size=(img_height, img_height),
title="Factual"
title="$(label_names[x+1]) (factual)",
dpi=300
)
_x = CounterfactualExplanations.counterfactual(ce)
p2 = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=([], false),
size=(img_height, img_height),
title="Counterfactual"
)
plt = Plots.plot(p1, p2; size=(img_height*2,img_height), layout=(1,2), kwrgs...)
plts = [p1]
for t in targets
ce = generate_counterfactual(
x_fact, t, data, model, generator;
decision_threshold=γ, max_iter=T,
initialization=init,
converge_when=:generator_conditions,
)
_x = CounterfactualExplanations.counterfactual(ce)
plt = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=([], false),
size=(img_height, img_height),
title="$(label_names[t+1])",
dpi=300
)
plts = [plts..., plt]
end
plt = Plots.plot(plts..., size=(img_height*length(plts),img_height), layout=(1,length(plts)), kwrgs...)
return plt
end
......@@ -235,7 +253,27 @@ end
```{julia}
rng = rand(1:10000)
factual = 3
plt = plot_fmnist(3, 1, rng=rng)
```
\ No newline at end of file
rng = 42
gen = ECCCoGenerator(
λ=[0.1,0.30,0.30],
nsamples=25,
nmin=10,
)
plts = []
plt = plot_fmnist(3, [1]; rng=rng, generator=gen, img_height=200)
display(plt)
push!(plts, plt)
savefig(plt, "dev/rebuttal/www/fmnist_dress.png")
plt = plot_fmnist(9, [5]; rng=rng, generator=gen, img_height=200)
display(plt)
push!(plts, plt)
savefig(plt, "dev/rebuttal/www/fmnist_boot.png")
plt = plot_fmnist(2, [4]; rng=rng, generator=gen, img_height=200)
display(plt)
push!(plts, plt)
savefig(plt, "dev/rebuttal/www/fmnist_pullover.png")
```
......@@ -613,9 +613,9 @@ display(plt_additional_models)
```
```{julia}
Random.seed!(2023)
Random.seed!(123)
λ = [0.1,0.215,0.215]
λ = [0.1,0.22,0.22]
wachter = WachterGenerator(
λ=λ[1],
......@@ -624,13 +624,18 @@ wachter = WachterGenerator(
combos = [
(9,7),
(7,2),
(6,0),
(4,1),
(2,3),
(5,8),
(6,5),
(1,7),
(1,4),
(0,3)
]
n_each = 2
n_each = 3
combos = reduce(vcat, [fill(c, n_each) for c in combos])
plts_eccco = []
......@@ -684,7 +689,7 @@ for (i, (factual, target)) in enumerate(combos)
p1 = plts_eccco[i]
p2 = plts_wachter[i]
plt = Plots.plot(
p1, p2, layout=(2,1), plot_title="Factual: $factual, Target: $target",
p1, p2, layout=(2,1),
size = (1200, 400),
)
display(plt)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment