Commit 8f5924f7 authored by Pat Alt

work on mnist example

parent 2c5ab5ae
100 additions and 82 deletions
artifacts/results/images/mnist_eccco.png

18.7 KiB

artifacts/results/images/mnist_generated_JEM Ensemble.png

159 KiB → 150 KiB
artifacts/results/images/mnist_generated_JEM.png

321 KiB → 327 KiB
artifacts/results/images/mnist_generated_MLP Ensemble.png

335 KiB → 334 KiB
artifacts/results/images/mnist_generated_MLP.png

325 KiB → 322 KiB
artifacts/results/images/surrogate_gone_wrong.png

14.5 KiB

artifacts/results/images/you_may_not_like_it.png

12.7 KiB

acc,precision,f1score,mod_name
0.91,0.9103128613256043,0.9085266898485427,JEM Ensemble
0.9423,0.9418035808913033,0.9415726345973308,MLP
0.9439,0.943391966967219,0.9432116151207373,MLP Ensemble
0.8748,0.8798390712341672,0.8728379219312089,JEM
0.9101,0.9104599093362784,0.9086219445938086,JEM Ensemble
0.942,0.9415017061499791,0.9412674228675252,MLP
0.9441,0.9435986059071019,0.9434079268844547,MLP Ensemble
0.8752,0.8804593411643,0.8733208976945687,JEM
......@@ -12,6 +12,8 @@ eval(setup_notebooks)
#### Wachter and JSMA
```{julia}
Random.seed!(1234)
# Data:
counterfactual_data = load_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
......@@ -20,25 +22,27 @@ M = load_mnist_mlp()
# Target:
factual_label = 8
x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]
γ = 0.9
T = 50
# Training params:
T = 100
opt = Flux.Optimise.Adam(0.01)
```
```{julia}
# Search:
opt = Flux.Optimise.Adam(0.01)
generator = GenericGenerator(opt=opt)
ce_wachter = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
generator = GreedyGenerator(η=1.0)
ce_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
......@@ -46,7 +50,7 @@ ce_jsma = generate_counterfactual(
```{julia}
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
......@@ -69,7 +73,7 @@ for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wach
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(www_path, "you_may_not_like_it.png"))
savefig(plt, joinpath(output_images_path, "you_may_not_like_it.png"))
```
#### REVISE
......@@ -87,21 +91,23 @@ Serialization.serialize(joinpath(output_path,"mnist_vae_weak.jls"), vae_weak)
# Define generator:
generator = REVISEGenerator(
opt = opt,
λ=0.01
λ=0.1
)
# Generate recourse:
counterfactual_data.generative_model = vae # assign generative model
ce_strong = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
counterfactual_data = deepcopy(counterfactual_data)
counterfactual_data.generative_model = vae_weak
ce_weak = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
```
......@@ -123,7 +129,7 @@ for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Stro
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(www_path, "surrogate_gone_wrong.png"))
savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png"))
```
### ECCCo
......@@ -137,12 +143,17 @@ end
```
```{julia}
# Hyper:
_retrain = false
_regen = false
# Data:
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
n_digits = Int(sqrt(input_dim))
......@@ -237,44 +248,49 @@ function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_nam
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
return M
end
model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict)
if _retrain
model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict)
else
model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models.jls"))
end
```
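The `_retrain` guard above follows a common cache-or-recompute pattern. A minimal sketch of how it could be factored into a helper, assuming a hypothetical function `load_or_train` that is not part of this commit or of any package used here; only the `Serialization` standard library is relied on:

```julia
using Serialization

# Hypothetical helper (an assumption, not part of the commit):
# return the cached object at `path` unless `retrain` is set
# or no cache file exists yet.
function load_or_train(train_fn, path; retrain=false)
    if retrain || !isfile(path)
        result = train_fn()          # expensive step, e.g. fitting all models
        serialize(path, result)      # refresh the cache on disk
        return result
    end
    return deserialize(path)         # cheap path: reuse the stored result
end
```

With such a helper the notebook chunk could read `model_dict = load_or_train(path; retrain=_retrain) do ... end`, where `path` points at `mnist_models.jls`. Unlike the guard in the diff, this sketch also retrains when the cache file is missing, which is a design choice rather than the commit's behaviour.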
```{julia}
# Plot generated samples:
for (mod_name, mod) in model_dict
if ECCCo._has_sampler(mod)
sampler = ECCCo._get_sampler(mod)
else
K = length(counterfactual_data.y_levels)
input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
𝒟x = Uniform(extrema(counterfactual_data.X)...)
𝒟y = Categorical(ones(K) ./ K)
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
end
opt = ImproperSGLD()
f(x) = logits(mod, x)
n_iter = 200
_w = 1500
plts = []
neach = 10
for i in 1:10
x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
xj = reshape(xj, (n_digits, n_digits))
plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
if _regen
for (mod_name, mod) in model_dict
if ECCCo._has_sampler(mod)
sampler = ECCCo._get_sampler(mod)
else
K = length(counterfactual_data.y_levels)
input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
𝒟x = Uniform(extrema(counterfactual_data.X)...)
𝒟y = Categorical(ones(K) ./ K)
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
end
plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
plts = [plts..., plt]
opt = ImproperSGLD()
f(x) = logits(mod, x)
n_iter = 200
_w = 1500
plts = []
neach = 10
for i in 1:10
x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
xj = reshape(xj, (n_digits, n_digits))
plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
end
plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
plts = [plts..., plt]
end
plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png"))
display(plt)
end
plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png"))
display(plt)
end
```
......@@ -290,7 +306,7 @@ model_performance = DataFrame()
for (mod_name, mod) in model_dict
# Test performance:
test_data = load_mnist_test()
test_data.X = pre_process.(test_data.X)
test_data.X = pre_process.(test_data.X, noise=0.0f0)
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
......@@ -302,15 +318,6 @@ model_performance
```
```{julia}
Random.seed!(123)
# Set up search:
factual = 8
x = reshape(counterfactual_data.X[:,rand(findall(labels.==factual))],input_dim,1)
target = 3
γ = 0.9
T = 100
# ECCCo:
λ=[0.5,0.1,0.5]
temp=0.1
......@@ -326,37 +333,41 @@ generator = ECCCoGenerator(
ces = Dict()
for (mod_name, mod) in model_dict
ce = generate_counterfactual(
x, target, counterfactual_data, mod, generator;
x_factual, target, counterfactual_data, mod, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
ces[mod_name] = ce
end
plt_order = sortperm(collect(keys(ces)))
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
)
plts = [p1]
plts = []
for (_name,ce) in ces
x = CounterfactualExplanations.counterfactual(ce)
_x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
convert2image(MNIST, reshape(_x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plts = plts[plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```
```{julia}
......@@ -364,9 +375,9 @@ Random.seed!(1234)
# Set up search:
factual_label = 9
x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
x_factual = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 7
factual = predict_label(M, counterfactual_data, x)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]
γ = 0.9
T = 100
......@@ -375,7 +386,7 @@ T = 100
# Generate counterfactual using generic generator:
generator = GenericGenerator(opt=Flux.Optimise.Adam(0.01),)
ce_wachter = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -383,7 +394,7 @@ ce_wachter = generate_counterfactual(
generator = GreedyGenerator(η=η)
ce_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -400,7 +411,7 @@ generator = ECCCoGenerator(
opt=Flux.Optimise.Adam(0.01),
)
ce_conformal = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -413,7 +424,7 @@ generator = ECCCoGenerator(
opt=CounterfactualExplanations.Generators.JSMADescent(η=η),
)
ce_conformal_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -421,7 +432,7 @@ ce_conformal_jsma = generate_counterfactual(
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
......@@ -432,11 +443,11 @@ ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma]
_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"]
for x in zip(ces, _names)
ce, _name = (x[1],x[2])
x = CounterfactualExplanations.counterfactual(ce)
x_factual = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
......@@ -453,23 +464,23 @@ savefig(plt, joinpath(www_path, "eccco_mnist.png"))
# Set up search:
factual_label = 8
x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
x_factual = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]
γ = 0.5
T = 100
# Generate counterfactual using generic generator:
generator = GenericGenerator(opt=Flux.Optimise.Adam(),)
ce_wachter = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
generator = GreedyGenerator(η=1.0)
ce_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
......@@ -485,7 +496,7 @@ generator = CCEGenerator(
opt=Flux.Optimise.Adam(),
)
ce_conformal = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -498,7 +509,7 @@ generator = CCEGenerator(
opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
)
ce_conformal_jsma = generate_counterfactual(
x, target, counterfactual_data, M, generator;
x_factual, target, counterfactual_data, M, generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -506,7 +517,7 @@ ce_conformal_jsma = generate_counterfactual(
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
......
......@@ -159,12 +159,12 @@ Surrogate models offer an obvious solution to achieve this objective. Unfortunat
\centering
\begin{minipage}[t]{0.45\textwidth}
\centering
\includegraphics[width=\textwidth]{../www/you_may_not_like_it.png}
\includegraphics[width=\textwidth]{../artifacts/results/images/you_may_not_like_it.png}
\caption{You may not like it, but this is what stripped-down counterfactuals look like. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); counterfactual produced using \citet{wachter2017counterfactual} (centre); and a counterfactual produced using the JSMA-based approach introduced by \citet{schut2021generating} (right).}\label{fig:adv}
\end{minipage}\hfill
\begin{minipage}[t]{0.45\textwidth}
\centering
\includegraphics[width=\textwidth]{../www/surrogate_gone_wrong.png}
\includegraphics[width=\textwidth]{../artifacts/results/images/surrogate_gone_wrong.png}
\caption{Using surrogates can improve plausibility, but also increases vulnerability. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); counterfactual produced using REVISE \citep{joshi2019realistic} with a well-specified surrogate (centre); and a counterfactual produced using REVISE \citep{joshi2019realistic} with a poorly specified surrogate (right).}\label{fig:vae}
\end{minipage}
\end{figure}
......@@ -217,6 +217,11 @@ Our framework for generating ECCCos combines the ideas introduced in the previou
\end{aligned}
\end{equation}
\begin{figure}
\includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png}
\caption{ECCCos from Black Boxes. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); }\label{fig:eccco}
\end{figure}
\section{Evaluation Framework}\label{conformity}
In Section~\ref{background} we explained that Counterfactual Explanations work directly with the Black Box Model, so fidelity is not a concern. This may explain why research has primarily focused on other desiderata, most notably plausibility (Definition~\ref{def:plausible}). Enquiring about the plausibility of a counterfactual essentially boils down to the following question: `Is this counterfactual consistent with the underlying data'? To introduce this section, we posit a related, slightly more nuanced question: `Is this counterfactual consistent with what the model has learned about the underlying data'? We will argue that fidelity is not a sufficient evaluation measure to answer this question and propose a novel way to assess if explanations conform with model behaviour. Finally, we will introduce a framework for Conformal Counterfactual Explanations that reconciles the notions of plausibility and model conformity.
......@@ -268,6 +273,8 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar
\item It seems that models that are not explicitly trained for the generative task still learn it implicitly
\item Batch size seems to impact quality of generated samples (at inference, but not so much during JEM training)
\item ECCCo is sensitive to optimizer (Adam works well), learning rate and distance metric (l1 works well)
\item SGLD takes time
\item REVISE has benefit of lower dimensional space
\end{itemize}
\section{Discussion}
......
......@@ -37,7 +37,7 @@ end
function distance_from_energy(
ce::AbstractCounterfactualExplanation;
n::Int=10, niter=200, from_buffer=true, agg=mean, kwargs...
n::Int=10, niter=250, from_buffer=true, agg=mean, kwargs...
)
conditional_samples = []
ignore_derivatives() do
......