diff --git a/artifacts/results/images/mnist_eccco.png b/artifacts/results/images/mnist_eccco.png index 217094c975213d0afc36922c9b9d0d1653d1b7d4..4d91b401d4b88aaeaae8660904fe687f128f74b3 100644 Binary files a/artifacts/results/images/mnist_eccco.png and b/artifacts/results/images/mnist_eccco.png differ diff --git a/artifacts/results/images/mnist_eccco_benchmark.png b/artifacts/results/images/mnist_eccco_benchmark.png index 9bf388b70297582d15f3accc2b6da83d46e512f1..49b590bc15223b0bf69a600c0648a91cc578c048 100644 Binary files a/artifacts/results/images/mnist_eccco_benchmark.png and b/artifacts/results/images/mnist_eccco_benchmark.png differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index e571c3f2ac2db78575bfa61439e4da7b817bb3f7..b0e32960a27364c9cf284cd5d8409ec363dd5299 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 5717d5eeb1a7e905354727a7553c701f328a93be..d3db5f46b841781aeaa993c20653bae898dc0aad 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index e8b52e2313bb4eed9641ade3ceef39b8925664fb..82d559fddefc8f283e66c3366ec4ff6fb9e9f48b 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -137,7 +137,6 @@ savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png")) ```{julia} function pre_process(x; noise::Float32=0.03f0) ϵ = Float32.(randn(size(x)) * noise) - x = @.(2 * x - 1) x += ϵ return x end @@ -149,7 +148,7 @@ _retrain = false _regen = false # Data: -n_obs = 10000 +n_obs = 1000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) @@ -187,7 +186,7 @@ sampler = ConditionalSampler( input_size=(input_dim,), batch_size=10, ) -α = [1.0,1.0,1e-2] # penalty strengths +α = [1.0,1.0,1e-1] # penalty strengths ``` ```{julia} @@ -307,7 +306,6 @@ model_performance = DataFrame() for (mod_name, mod) in model_dict # Test performance: test_data = load_mnist_test() - test_data.X = pre_process.(test_data.X, noise=0.0f0) _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure))) _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) _perf.mod_name .= mod_name @@ -321,58 +319,132 @@ model_performance ### Different Models ```{julia} -plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"] - -# ECCCo: -λ=[0.5,0.1,0.5] -temp=0.5 -η=0.01 - -# Generate counterfactuals using ECCCo generator: -eccco_generator = ECCCoGenerator( - λ=λ, - temp=temp, - opt=Flux.Optimise.Adam(η), +function _plot_eccco_mnist( + x::Union{AbstractArray, Int}=x_factual, target::Int=target; + λ=[0.1,0.1,0.1], + temp=0.1,η=0.01, + plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"], + opt = Flux.Optimise.Adam(η), + rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, ) -ces = Dict() -for (mod_name, mod) in model_dict - ce = generate_counterfactual( - x_factual, target, counterfactual_data, mod, eccco_generator; - decision_threshold=γ, max_iter=T, - initialization=:identity, - converge_when=:generator_conditions, + # Setup: + Random.seed!(rng) + if x isa Int + x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1) + end + + # Generate counterfactuals using ECCCo generator: + eccco_generator = ECCCoGenerator( + λ=λ, + temp=temp, + opt=opt, + ) + + ces = Dict() + for (mod_name, mod) in model_dict + ce = generate_counterfactual( + x, target, counterfactual_data, mod, eccco_generator; + decision_threshold=γ, max_iter=T, + initialization=:identity, + converge_when=:generator_conditions, + ) + ces[mod_name] = ce + end + _plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order) + + # Plot: + p1 = Plots.plot( + convert2image(MNIST, reshape(x,28,28)), + axis=nothing, + size=(img_height, img_height), + title="Factual" ) - ces[mod_name] = ce + + plts = [] + for (_name,ce) in ces + _x = CounterfactualExplanations.counterfactual(ce) + _phat = target_probs(ce) + _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" + plt = Plots.plot( + convert2image(MNIST, reshape(_x,28,28)), + axis=nothing, + size=(img_height, img_height), + title=_title + ) + plts = [plts..., plt] + end + plts = plts[_plt_order] + plts = [p1, plts...] + plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) + + return plt, eccco_generator end -_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order) +``` -# Plot: -p1 = Plots.plot( - convert2image(MNIST, reshape(x_factual,28,28)), - axis=nothing, - size=(img_height, img_height), - title="Factual" +```{julia} +plt, eccco_generator = _plot_eccco_mnist() +display(plt) +savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) +``` + +### All digits + +```{julia} +function plot_mnist( + factual::Int,target::Int; + generator::AbstractGenerator, + model::AbstractFittedModel=model_dict["JEM Ensemble"], + data::CounterfactualData=counterfactual_data, + rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG, + _plot_title::Bool=true, + kwargs..., ) + decision_threshold = !isdefined(kwargs, :decision_threshold) ? 0.9 : decision_threshold + max_iter = !isdefined(kwargs, :max_iter) ? 100 : max_iter + initialization = !isdefined(kwargs, :initialization) ? :identity : initialization + converge_when = !isdefined(kwargs, :converge_when) ? :generator_conditions : converge_when + + x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1) + ce = generate_counterfactual( + x, target, data, model, generator; + decision_threshold=decision_threshold, max_iter=max_iter, + initialization=initialization, + converge_when=converge_when, + kwargs... + ) + + _title = _plot_title ? "$(factual) -> $(target)" : "" -plts = [] -for (_name,ce) in ces _x = CounterfactualExplanations.counterfactual(ce) - _phat = target_probs(ce) - _title = "$_name (p̂=$(round(_phat[1]; digits=3)))" plt = Plots.plot( convert2image(MNIST, reshape(_x,28,28)), axis=nothing, size=(img_height, img_height), title=_title ) - plts = [plts..., plt] + return plt +end +``` + +```{julia} +if _regen + function plot_all_digits(rng=1;verbose=true,kwargs...) + plts = [] + for i in 0:9 + for j in 0:9 + @info "Generating counterfactual for $(i) -> $(j)" + plt = plot_mnist(i,j;kwargs...,rng=rng) + !verbose || display(plt) + plts = [plts..., plt] + end + end + plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10)) + return plt + end + plt = plot_all_digits(generator=eccco_generator) + savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png")) end -plts = plts[_plt_order] -plts = [p1, plts...] -plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts))) -display(plt) -savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) ``` ### Different Generators @@ -432,21 +504,12 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco_benchmark.png")) ## Benchmark ```{julia} -# Benchmark generators: -generators = Dict( - :wachter => GenericGenerator(opt=opt, λ=l2_λ), - :revise => REVISEGenerator(opt=opt, λ=l2_λ), - :greedy => GreedyGenerator(), -) - -# Conformal Models: - - # Measures: measures = [ CounterfactualExplanations.distance, ECCCo.distance_from_energy, ECCCo.distance_from_targets, CounterfactualExplanations.validity, + CounterfactualExplanations.redudancy, ] ``` \ No newline at end of file diff --git a/paper/paper.pdf b/paper/paper.pdf index 2859dae90c8dc3dd0fa0be53ee9a0557c3ac9354..b7ffb3b71de08967964800dacafa97d294ab1dfa 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/src/generator.jl b/src/generator.jl index 72952aa31de89bf7d1a8d297a8337abe97cf9a16..87b1dceff3e45e92075efdfb5a69f614eb15e2af 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -28,7 +28,7 @@ function ECCCoGenerator(; function _set_size_penalty(ce::AbstractCounterfactualExplanation) return ECCCo.set_size_penalty(ce; κ=κ, temp=temp) end - _penalties = [Objectives.distance_l2, _set_size_penalty, ECCCo.distance_from_energy] + _penalties = [Objectives.distance_l1, _set_size_penalty, ECCCo.distance_from_energy] λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...) end diff --git a/src/penalties.jl b/src/penalties.jl index 19f06ae48de244bc691df38268ec92322554b139..92b007d06e219bf2bb3b69952d286619a0c5f8a7 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -1,5 +1,6 @@ using ChainRules: ignore_derivatives using Distances +using Flux using LinearAlgebra: norm using Statistics: mean @@ -37,7 +38,7 @@ end function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=250, from_buffer=true, agg=mean, kwargs... + n::Int=10, niter=100, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do @@ -46,7 +47,7 @@ function distance_from_energy( _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) end sampler = _dict[:energy_sampler] - push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) + push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer)) end x′ = CounterfactualExplanations.counterfactual(ce) loss = map(eachslice(x′, dims=ndims(x′))) do x