diff --git a/artifacts/results/images/mnist_eccco.png b/artifacts/results/images/mnist_eccco.png new file mode 100644 index 0000000000000000000000000000000000000000..8f18c4d1c8fedef587e36d2d0e59fc724fa840de Binary files /dev/null and b/artifacts/results/images/mnist_eccco.png differ diff --git a/artifacts/results/images/mnist_generated_JEM Ensemble.png b/artifacts/results/images/mnist_generated_JEM Ensemble.png index 015cddd10299c4734b5faffcce4585b0582150a4..a2d8a81bf78e3778ae2c9b68aae775f5639f74a7 100644 Binary files a/artifacts/results/images/mnist_generated_JEM Ensemble.png and b/artifacts/results/images/mnist_generated_JEM Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_JEM.png b/artifacts/results/images/mnist_generated_JEM.png index 6ff8e5780dae3cd4ac3c775c7faf31801ccc6f5a..db4e965990f722732a69f3b92f3c0bb2264cf43d 100644 Binary files a/artifacts/results/images/mnist_generated_JEM.png and b/artifacts/results/images/mnist_generated_JEM.png differ diff --git a/artifacts/results/images/mnist_generated_MLP Ensemble.png b/artifacts/results/images/mnist_generated_MLP Ensemble.png index 4834dde4916fd890c7b13523148ffee9405ef0c7..02468de5773695e05a723fdbe08541ed31a552b0 100644 Binary files a/artifacts/results/images/mnist_generated_MLP Ensemble.png and b/artifacts/results/images/mnist_generated_MLP Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_MLP.png b/artifacts/results/images/mnist_generated_MLP.png index 02e86da900666b6c564655448a764fc347228264..029c632767e0e4b1d9ee14efcd7b2a7855322e88 100644 Binary files a/artifacts/results/images/mnist_generated_MLP.png and b/artifacts/results/images/mnist_generated_MLP.png differ diff --git a/artifacts/results/images/surrogate_gone_wrong.png b/artifacts/results/images/surrogate_gone_wrong.png new file mode 100644 index 0000000000000000000000000000000000000000..d0904be0d5275cde9d790d7fdeae672feef54c6c Binary files /dev/null and b/artifacts/results/images/surrogate_gone_wrong.png 
differ diff --git a/artifacts/results/images/you_may_not_like_it.png b/artifacts/results/images/you_may_not_like_it.png new file mode 100644 index 0000000000000000000000000000000000000000..87abce43d2b3ffef3307ebbc04610564853b932a Binary files /dev/null and b/artifacts/results/images/you_may_not_like_it.png differ diff --git a/artifacts/results/mnist_model_performance.csv b/artifacts/results/mnist_model_performance.csv index 19ef85f908ff9b49191bb26f1d7f5b9286991176..61dd7f8de0d3e2bc61fe9523b8b4717dc9e9b0ea 100644 --- a/artifacts/results/mnist_model_performance.csv +++ b/artifacts/results/mnist_model_performance.csv @@ -1,5 +1,5 @@ acc,precision,f1score,mod_name -0.91,0.9103128613256043,0.9085266898485427,JEM Ensemble -0.9423,0.9418035808913033,0.9415726345973308,MLP -0.9439,0.943391966967219,0.9432116151207373,MLP Ensemble -0.8748,0.8798390712341672,0.8728379219312089,JEM +0.9101,0.9104599093362784,0.9086219445938086,JEM Ensemble +0.942,0.9415017061499791,0.9412674228675252,MLP +0.9441,0.9435986059071019,0.9434079268844547,MLP Ensemble +0.8752,0.8804593411643,0.8733208976945687,JEM diff --git a/artifacts/results/mnist_model_performance.jls b/artifacts/results/mnist_model_performance.jls index 9e6a885e0eeb8b869672d63905a5644cbde12436..f0a8759eb39e6ed0d858bd21fc462e4617e7ffcd 100644 Binary files a/artifacts/results/mnist_model_performance.jls and b/artifacts/results/mnist_model_performance.jls differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index 6167b3c118ef4be2c37a9a9a4f84af34b3c5d260..dbbf8c47445a74c4a539ffda042084fc3873bdd6 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index b81b474a74b5ffd83b8650a49727326eb905b732..2c434b6fb6e97be2e245a955fe2d594e718a9aeb 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git 
a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 7cf36c56db015edbd0be538cb7ee9e6f18922ba4..93d5d668d73d0895f08f6ae48f80c5611f0c7dff 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -12,6 +12,8 @@ eval(setup_notebooks) #### Wachter and JSMA ```{julia} +Random.seed!(1234) + # Data: counterfactual_data = load_mnist() X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) @@ -20,25 +22,27 @@ M = load_mnist_mlp() # Target: factual_label = 8 -x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) target = 3 -factual = predict_label(M, counterfactual_data, x)[1] +factual = predict_label(M, counterfactual_data, x_factual)[1] γ = 0.9 -T = 50 + +# Training params: +T = 100 +opt = Flux.Optimise.Adam(0.01) ``` ```{julia} # Search: -opt = Flux.Optimise.Adam(0.01) generator = GenericGenerator(opt=opt) ce_wachter = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) generator = GreedyGenerator(η=1.0) ce_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) @@ -46,7 +50,7 @@ ce_jsma = generate_counterfactual( ```{julia} p1 = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), + convert2image(MNIST, reshape(x_factual,28,28)), axis=nothing, size=(img_height, img_height), title="Factual" @@ -69,7 +73,7 @@ for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wach end plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) display(plt) -savefig(plt, joinpath(www_path, "you_may_not_like_it.png")) +savefig(plt, joinpath(output_images_path, 
"you_may_not_like_it.png")) ``` #### REVISE @@ -87,21 +91,23 @@ Serialization.serialize(joinpath(output_path,"mnist_vae_weak.jls"), vae_weak) # Define generator: generator = REVISEGenerator( opt = opt, - λ=0.01 + λ=0.1 ) # Generate recourse: counterfactual_data.generative_model = vae # assign generative model ce_strong = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, + converge_when=:generator_conditions, ) counterfactual_data = deepcopy(counterfactual_data) counterfactual_data.generative_model = vae_weak ce_weak = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, + converge_when=:generator_conditions, ) ``` @@ -123,7 +129,7 @@ for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Stro end plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) display(plt) -savefig(plt, joinpath(www_path, "surrogate_gone_wrong.png")) +savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png")) ``` ### ECCCo @@ -137,12 +143,17 @@ end ``` ```{julia} +# Hyper: +_retrain = false +_regen = false + # Data: n_obs = 10000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) +x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1) labels = counterfactual_data.output_encoder.labels input_dim, n_obs = size(counterfactual_data.X) n_digits = Int(sqrt(input_dim)) @@ -237,44 +248,49 @@ function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_nam M = ECCCo.ConformalModel(mach.model, mach.fitresult) return M end -model_dict = Dict(mod_name => _train(mod; 
mod_name=mod_name) for (mod_name, mod) in models) -Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict) +if _retrain + model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) + Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict) +else + model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models.jls")) +end ``` ```{julia} # Plot generated samples: - -for (mod_name, mod) in model_dict - if ECCCo._has_sampler(mod) - sampler = ECCCo._get_sampler(mod) - else - K = length(counterfactual_data.y_levels) - input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) - 𝒟x = Uniform(extrema(counterfactual_data.X)...) - 𝒟y = Categorical(ones(K) ./ K) - sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size) - end - opt = ImproperSGLD() - f(x) = logits(mod, x) - - n_iter = 200 - _w = 1500 - plts = [] - neach = 10 - for i in 1:10 - x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) - plts_i = [] - for j in 1:size(x, 2) - xj = x[:,j] - xj = reshape(xj, (n_digits, n_digits)) - plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] +if _regen + for (mod_name, mod) in model_dict + if ECCCo._has_sampler(mod) + sampler = ECCCo._get_sampler(mod) + else + K = length(counterfactual_data.y_levels) + input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) + 𝒟x = Uniform(extrema(counterfactual_data.X)...)
+ 𝒟y = Categorical(ones(K) ./ K) + sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size) end - plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) - plts = [plts..., plt] + opt = ImproperSGLD() + f(x) = logits(mod, x) + + n_iter = 200 + _w = 1500 + plts = [] + neach = 10 + for i in 1:10 + x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) + plts_i = [] + for j in 1:size(x, 2) + xj = x[:,j] + xj = reshape(xj, (n_digits, n_digits)) + plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] + end + plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) + plts = [plts..., plt] + end + plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name) + savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png")) + display(plt) end - plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name) - savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png")) - display(plt) end ``` @@ -290,7 +306,7 @@ model_performance = DataFrame() for (mod_name, mod) in model_dict # Test performance: test_data = load_mnist_test() - test_data.X = pre_process.(test_data.X) + test_data.X = pre_process.(test_data.X, noise=0.0f0) _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name @@ -302,15 +318,6 @@ model_performance ``` ```{julia} -Random.seed!(123) - -# Set up search: -factual = 8 -x = reshape(counterfactual_data.X[:,rand(findall(labels.==factual))],input_dim,1) -target = 3 -γ = 0.9 -T = 100 - # ECCCo: λ=[0.5,0.1,0.5] temp=0.1 @@ -326,37 +333,41 @@ generator = ECCCoGenerator( ces = Dict() for (mod_name, mod) in model_dict ce = generate_counterfactual( - x, target, counterfactual_data, mod, generator; + x_factual, target, counterfactual_data, mod, generator; decision_threshold=γ, max_iter=T, initialization=:identity,
converge_when=:generator_conditions, ) ces[mod_name] = ce end +plt_order = sortperm(collect(keys(ces))) # Plot: p1 = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), + convert2image(MNIST, reshape(x_factual,28,28)), axis=nothing, size=(img_height, img_height), title="Factual" ) -plts = [p1] +plts = [] for (_name,ce) in ces - x = CounterfactualExplanations.counterfactual(ce) + _x = CounterfactualExplanations.counterfactual(ce) _phat = target_probs(ce) _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" plt = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), + convert2image(MNIST, reshape(_x,28,28)), axis=nothing, size=(img_height, img_height), title=_title ) plts = [plts..., plt] end +plts = plts[plt_order] +plts = [p1, plts...] plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) display(plt) +savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) ``` ```{julia} @@ -364,9 +375,9 @@ Random.seed!(1234) # Set up search: factual_label = 9 -x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +x_factual = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) target = 7 -factual = predict_label(M, counterfactual_data, x)[1] +factual = predict_label(M, counterfactual_data, x_factual)[1] γ = 0.9 T = 100 @@ -375,7 +386,7 @@ T = 100 # Generate counterfactual using generic generator: generator = GenericGenerator(opt=Flux.Optimise.Adam(0.01),) ce_wachter = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, @@ -383,7 +394,7 @@ ce_wachter = generate_counterfactual( generator = GreedyGenerator(η=η) ce_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M,
generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, @@ -400,7 +411,7 @@ generator = ECCCoGenerator( opt=Flux.Optimise.Adam(0.01), ) ce_conformal = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, @@ -413,7 +424,7 @@ generator = ECCCoGenerator( opt=CounterfactualExplanations.Generators.JSMADescent(η=η), ) ce_conformal_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, @@ -421,7 +432,7 @@ ce_conformal_jsma = generate_counterfactual( # Plot: p1 = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), + convert2image(MNIST, reshape(x_factual,28,28)), axis=nothing, size=(img_height, img_height), title="Factual" @@ -432,11 +443,11 @@ ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma] _names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"] for x in zip(ces, _names) ce, _name = (x[1],x[2]) - x = CounterfactualExplanations.counterfactual(ce) + _x = CounterfactualExplanations.counterfactual(ce) _phat = target_probs(ce) _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" plt = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), + convert2image(MNIST, reshape(_x,28,28)), axis=nothing, size=(img_height, img_height), title=_title @@ -453,23 +464,23 @@ savefig(plt, joinpath(www_path, "eccco_mnist.png")) # Set up search: factual_label = 8 -x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) +x_factual = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1) target = 3 -factual = predict_label(M, counterfactual_data, x)[1]
+factual = predict_label(M, counterfactual_data, x_factual)[1] γ = 0.5 T = 100 # Generate counterfactual using generic generator: generator = GenericGenerator(opt=Flux.Optimise.Adam(),) ce_wachter = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) generator = GreedyGenerator(η=1.0) ce_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, ) @@ -485,7 +496,7 @@ generator = CCEGenerator( opt=Flux.Optimise.Adam(), ) ce_conformal = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, @@ -498,7 +509,7 @@ generator = CCEGenerator( opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0), ) ce_conformal_jsma = generate_counterfactual( - x, target, counterfactual_data, M, generator; + x_factual, target, counterfactual_data, M, generator; decision_threshold=γ, max_iter=T, initialization=:identity, converge_when=:generator_conditions, @@ -506,7 +517,7 @@ ce_conformal_jsma = generate_counterfactual( # Plot: p1 = Plots.plot( - convert2image(MNIST, reshape(x,28,28)), + convert2image(MNIST, reshape(x_factual,28,28)), axis=nothing, size=(img_height, img_height), title="Factual" diff --git a/paper/paper.pdf b/paper/paper.pdf index 6948356686c7558b15c1d17fd67d8c7613b6f4da..828c613ca85465b91d18e1fc904c22c1cef04531 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/paper/paper.tex b/paper/paper.tex index fb8ddcdfc1de94d649676f5f0e78094419d70761..663707bc3de3fe959e99901b0341701381a86205 100644 --- a/paper/paper.tex +++ b/paper/paper.tex @@ -159,12 +159,12 @@ Surrogate models offer an obvious solution to 
achieve this objective. Unfortunat \centering \begin{minipage}[t]{0.45\textwidth} \centering - \includegraphics[width=\textwidth]{../www/you_may_not_like_it.png} + \includegraphics[width=\textwidth]{../artifacts/results/images/you_may_not_like_it.png} \caption{You may not like it, but this is what stripped-down counterfactuals look like. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); counterfactual produced using \citet{wachter2017counterfactual} (centre); and a counterfactual produced using JSMA-based approach introduced by \citep{schut2021generating}.}\label{fig:adv} \end{minipage}\hfill \begin{minipage}[t]{0.45\textwidth} \centering - \includegraphics[width=\textwidth]{../www/surrogate_gone_wrong.png} + \includegraphics[width=\textwidth]{../artifacts/results/images/surrogate_gone_wrong.png} \caption{Using surrogates can improve plausibility, but also increases vulnerability. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); counterfactual produced using REVISE \citep{joshi2019realistic} with a well-specified surrogate (centre); and a counterfactual produced using REVISE \citep{joshi2019realistic} with a poorly specified surrogate (right).}\label{fig:vae} \end{minipage} \end{figure} @@ -217,6 +217,11 @@ Our framework for generating ECCCos combines the ideas introduced in the previou \end{aligned} \end{equation} +\begin{figure} + \includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png} + \caption{ECCCos from Black Boxes. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); ECCCos generated for each of the black-box models (remaining panels).}\label{fig:eccco} +\end{figure} + \section{Evaluation Framework}\label{conformity} In Section~\ref{background} we explained that Counterfactual Explanations work directly with Black Box Model, so fidelity is not a concern. This may explain why research has primarily focused on other desiderata, most notably plausibility (Definition~\ref{def:plausible}).
Enquiring about the plausibility of a counterfactual essentially boils down to the following question: `Is this counterfactual consistent with the underlying data'? To introduce this section, we posit a related, slightly more nuanced question: `Is this counterfactual consistent with what the model has learned about the underlying data'? We will argue that fidelity is not a sufficient evaluation measure to answer this question and propose a novel way to assess if explanations conform with model behaviour. Finally, we will introduce a framework for Conformal Counterfactual Explanations, that reconciles the notions of plausibility and model conformity. @@ -268,6 +273,8 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar \item It seems that models that are not explicitly trained for generative task, still learn it implictly \item Batch size seems to impact quality of generated samples (at inference, but not so much during JEM training) \item ECCCo is sensitive to optimizer (Adam works well), learning rate and distance metric (l1 works well) + \item SGLD sampling is time-consuming + \item REVISE has the benefit of operating in a lower-dimensional space \end{itemize} \section{Discussion} diff --git a/src/penalties.jl b/src/penalties.jl index d3fdf8e15e6ba8703e5e7dc2cd8583f4a278713e..19f06ae48de244bc691df38268ec92322554b139 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -37,7 +37,7 @@ end function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=200, from_buffer=true, agg=mean, kwargs... + n::Int=10, niter=250, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do