diff --git a/artifacts/results/images/mnist_eccco.png b/artifacts/results/images/mnist_eccco.png index 217094c975213d0afc36922c9b9d0d1653d1b7d4..4d91b401d4b88aaeaae8660904fe687f128f74b3 100644 Binary files a/artifacts/results/images/mnist_eccco.png and b/artifacts/results/images/mnist_eccco.png differ diff --git a/artifacts/results/images/mnist_eccco_benchmark.png b/artifacts/results/images/mnist_eccco_benchmark.png index 9bf388b70297582d15f3accc2b6da83d46e512f1..49b590bc15223b0bf69a600c0648a91cc578c048 100644 Binary files a/artifacts/results/images/mnist_eccco_benchmark.png and b/artifacts/results/images/mnist_eccco_benchmark.png differ diff --git a/artifacts/results/images/mnist_generated_JEM Ensemble.png b/artifacts/results/images/mnist_generated_JEM Ensemble.png index f222cade2979638496bfc77d811b53dd850a4ed3..e8fa0be086cc9fdb9e01619379071931a0271688 100644 Binary files a/artifacts/results/images/mnist_generated_JEM Ensemble.png and b/artifacts/results/images/mnist_generated_JEM Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_JEM.png b/artifacts/results/images/mnist_generated_JEM.png index 28a1647f4120deb6eafcdbe88dd653aa9033c90d..cbf5a86c5c451ef194741f77cedb9212974966aa 100644 Binary files a/artifacts/results/images/mnist_generated_JEM.png and b/artifacts/results/images/mnist_generated_JEM.png differ diff --git a/artifacts/results/images/mnist_generated_MLP Ensemble.png b/artifacts/results/images/mnist_generated_MLP Ensemble.png index 7a21fab7a509b8731c85f90db7da61427686afd0..262854214927e3210605c6ed52e52ce9620c7cc9 100644 Binary files a/artifacts/results/images/mnist_generated_MLP Ensemble.png and b/artifacts/results/images/mnist_generated_MLP Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_MLP.png b/artifacts/results/images/mnist_generated_MLP.png index 970e9c6879f11d61a593a0028ac72dad421282c4..bc1c4a5935f6491d81dbdab6b489d1b8ff377b77 100644 Binary files a/artifacts/results/images/mnist_generated_MLP.png and b/artifacts/results/images/mnist_generated_MLP.png differ diff --git a/artifacts/results/mnist_model_performance.csv b/artifacts/results/mnist_model_performance.csv index 64433b21889e5308f8754b191f7e1c3241792341..3d5211537cb0d8d56706dd534c040662b50ab33d 100644 --- a/artifacts/results/mnist_model_performance.csv +++ b/artifacts/results/mnist_model_performance.csv @@ -1,5 +1,5 @@ acc,precision,f1score,mod_name -0.9019,0.9013187021264487,0.9006731391569962,JEM Ensemble -0.9431,0.9429692537123295,0.942515960085936,MLP -0.9449,0.9443587962537836,0.9442779098066854,MLP Ensemble -0.8374,0.8446918423076417,0.8342455855332602,JEM +0.9261,0.9260228514407332,0.9251619158437216,JEM Ensemble +0.9419,0.9412531322556057,0.9411680671494926,MLP +0.9417,0.9412171617920995,0.9410463995816242,MLP Ensemble +0.8948,0.8974775842341151,0.8936533031460684,JEM diff --git a/artifacts/results/mnist_model_performance.jls b/artifacts/results/mnist_model_performance.jls index f18e2c3529f5cb8ffc839f5fb918032ef451b469..07e88baf19668fcb0fb43827d7b3c027d338c391 100644 Binary files a/artifacts/results/mnist_model_performance.jls and b/artifacts/results/mnist_model_performance.jls differ diff --git a/artifacts/results/mnist_models.jls b/artifacts/results/mnist_models.jls index bf6810b5397cc4a444ab90ae830a00c35c22d198..e7847ce66b9e28488066ba5fee81411b78b8bee6 100644 Binary files a/artifacts/results/mnist_models.jls and b/artifacts/results/mnist_models.jls differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index e571c3f2ac2db78575bfa61439e4da7b817bb3f7..b0e32960a27364c9cf284cd5d8409ec363dd5299 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 5717d5eeb1a7e905354727a7553c701f328a93be..d3db5f46b841781aeaa993c20653bae898dc0aad 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index e8b52e2313bb4eed9641ade3ceef39b8925664fb..82d559fddefc8f283e66c3366ec4ff6fb9e9f48b 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -137,7 +137,6 @@ savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png")) ```{julia} function pre_process(x; noise::Float32=0.03f0) ϵ = Float32.(randn(size(x)) * noise) - x = @.(2 * x - 1) x += ϵ return x end @@ -149,7 +148,7 @@ _retrain = false _regen = false # Data: -n_obs = 10000 +n_obs = 1000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) @@ -187,7 +186,7 @@ sampler = ConditionalSampler( input_size=(input_dim,), batch_size=10, ) -α = [1.0,1.0,1e-2] # penalty strengths +α = [1.0,1.0,1e-1] # penalty strengths ``` ```{julia} @@ -307,7 +306,6 @@ model_performance = DataFrame() for (mod_name, mod) in model_dict # Test performance: test_data = load_mnist_test() - test_data.X = pre_process.(test_data.X, noise=0.0f0) _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name @@ -321,58 +319,132 @@ model_performance ### Different Models ```{julia} -plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"] - -# ECCCo: -λ=[0.5,0.1,0.5] -temp=0.5 -η=0.01 - -# Generate counterfactuals using ECCCo generator: -eccco_generator = ECCCoGenerator( - λ=λ, - temp=temp, - opt=Flux.Optimise.Adam(η), +function _plot_eccco_mnist( + x::Union{AbstractArray, Int}=x_factual, target::Int=target; + λ=[0.1,0.1,0.1], + temp=0.1,η=0.01, + plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"], + opt = Flux.Optimise.Adam(η), + rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, ) -ces = Dict() -for (mod_name, mod) in model_dict - ce = generate_counterfactual( - x_factual, target, counterfactual_data, mod, eccco_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, + # Setup: + Random.seed!(rng) + if x isa Int + x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1) + end + + # Generate counterfactuals using ECCCo generator: + eccco_generator = ECCCoGenerator( + λ=λ, + temp=temp, + opt=opt, + ) + + ces = Dict() + for (mod_name, mod) in model_dict + ce = generate_counterfactual( + x, target, counterfactual_data, mod, eccco_generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, + ) + ces[mod_name] = ce + end + _plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order) + + # Plot: + p1 = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title="Factual" ) - ces[mod_name] = ce + + plts = [] + for (_name,ce) in ces + _x = CounterfactualExplanations.counterfactual(ce) + _phat = target_probs(ce) + _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" + plt = Plots.plot( + convert2image(MNIST, reshape(_x,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] + end + plts = plts[_plt_order] + plts = [p1, plts...] + plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) + + return plt, eccco_generator end -_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order) +``` -# Plot: -p1 = Plots.plot( - convert2image(MNIST, reshape(x_factual,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" +```{julia} +plt, eccco_generator = _plot_eccco_mnist() +display(plt) +savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) +``` + +### All digits + +```{julia} +function plot_mnist( + factual::Int,target::Int; + generator::AbstractGenerator, + model::AbstractFittedModel=model_dict["JEM Ensemble"], + data::CounterfactualData=counterfactual_data, + rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, + _plot_title::Bool=true, + kwargs..., ) + decision_threshold = !isdefined(kwargs, :decision_threshold) ? 0.9 : decision_threshold + max_iter = !isdefined(kwargs, :max_iter) ? 100 : max_iter + initialization = !isdefined(kwargs, :initialization) ? :identity : initialization + converge_when = !isdefined(kwargs, :converge_when) ? :generator_conditions : converge_when + + x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1) + ce = generate_counterfactual( + x, target, data, model, generator; + decision_threshold=decision_threshold, max_iter=max_iter, + initialization=initialization, + converge_when=converge_when, + kwargs... + ) + + _title = _plot_title ? "$(factual) -> $(target)" : "" -plts = [] -for (_name,ce) in ces _x = CounterfactualExplanations.counterfactual(ce) - _phat = target_probs(ce) - _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" plt = Plots.plot( convert2image(MNIST, reshape(_x,28,28)), axis=nothing, size=(img_height, img_height), title=_title ) - plts = [plts..., plt] + return plt +end +``` + +```{julia} +if _regen + function plot_all_digits(rng=1;verbose=true,kwargs...) + plts = [] + for i in 0:9 + for j in 0:9 + @info "Generating counterfactual for $(i) -> $(j)" + plt = plot_mnist(i,j;kwargs...,rng=rng) + !verbose || display(plt) + plts = [plts..., plt] + end + end + plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10)) + return plt + end + plt = plot_all_digits(generator=eccco_generator) + savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png")) end -plts = plts[_plt_order] -plts = [p1, plts...] -plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) -display(plt) -savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) ``` ### Different Generators @@ -432,21 +504,12 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco_benchmark.png")) ## Benchmark ```{julia} -# Benchmark generators: -generators = Dict( - :wachter => GenericGenerator(opt=opt, λ=l2_λ), - :revise => REVISEGenerator(opt=opt, λ=l2_λ), - :greedy => GreedyGenerator(), -) - -# Conformal Models: - - # Measures: measures = [ CounterfactualExplanations.distance, ECCCo.distance_from_energy, ECCCo.distance_from_targets, CounterfactualExplanations.validity, + CounterfactualExplanations.redudancy, ] ``` \ No newline at end of file diff --git a/paper/paper.pdf b/paper/paper.pdf index a954a5cb646cfbca83f79dc2416ace000beea094..de806dd9399e7a53bec610210a08b01ecdbc9aa9 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/src/generator.jl b/src/generator.jl index 72952aa31de89bf7d1a8d297a8337abe97cf9a16..87b1dceff3e45e92075efdfb5a69f614eb15e2af 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -28,7 +28,7 @@ function ECCCoGenerator(; function _set_size_penalty(ce::AbstractCounterfactualExplanation) return ECCCo.set_size_penalty(ce; κ=κ, temp=temp) end - _penalties = [Objectives.distance_l2, _set_size_penalty, ECCCo.distance_from_energy] + _penalties = [Objectives.distance_l1, _set_size_penalty, ECCCo.distance_from_energy] λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...) end diff --git a/src/penalties.jl b/src/penalties.jl index 19f06ae48de244bc691df38268ec92322554b139..92b007d06e219bf2bb3b69952d286619a0c5f8a7 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -1,5 +1,6 @@ using ChainRules: ignore_derivatives using Distances +using Flux using LinearAlgebra: norm using Statistics: mean @@ -37,7 +38,7 @@ end function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=250, from_buffer=true, agg=mean, kwargs... + n::Int=10, niter=100, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do @@ -46,7 +47,7 @@ function distance_from_energy( _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) end sampler = _dict[:energy_sampler] - push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) + push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer)) end x′ = CounterfactualExplanations.counterfactual(ce) loss = map(eachslice(x′, dims=ndims(x′))) do x