Skip to content
Snippets Groups Projects
Commit 2eca812d authored by Pat Alt's avatar Pat Alt
Browse files

for 0 to 1 counterfactuals are looking a bit to homogenous

parent 22e3b6ee
No related branches found
No related tags found
No related merge requests found
artifacts/results/images/mnist_eccco.png

18.2 KiB | W: | H:

artifacts/results/images/mnist_eccco.png

22.2 KiB | W: | H:

artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
  • 2-up
  • Swipe
  • Onion skin
artifacts/results/images/mnist_eccco_benchmark.png

20.3 KiB | W: | H:

artifacts/results/images/mnist_eccco_benchmark.png

22.2 KiB | W: | H:

artifacts/results/images/mnist_eccco_benchmark.png
artifacts/results/images/mnist_eccco_benchmark.png
artifacts/results/images/mnist_eccco_benchmark.png
artifacts/results/images/mnist_eccco_benchmark.png
  • 2-up
  • Swipe
  • Onion skin
No preview for this file type
No preview for this file type
......@@ -137,7 +137,6 @@ savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png"))
```{julia}
function pre_process(x; noise::Float32=0.03f0)
ϵ = Float32.(randn(size(x)) * noise)
x = @.(2 * x - 1)
x += ϵ
return x
end
......@@ -149,7 +148,7 @@ _retrain = false
_regen = false
# Data:
n_obs = 10000
n_obs = 1000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
......@@ -187,7 +186,7 @@ sampler = ConditionalSampler(
input_size=(input_dim,),
batch_size=10,
)
α = [1.0,1.0,1e-2] # penalty strengths
α = [1.0,1.0,1e-1] # penalty strengths
```
```{julia}
......@@ -307,7 +306,6 @@ model_performance = DataFrame()
for (mod_name, mod) in model_dict
# Test performance:
test_data = load_mnist_test()
test_data.X = pre_process.(test_data.X, noise=0.0f0)
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
......@@ -321,58 +319,132 @@ model_performance
### Different Models
```{julia}
plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"]
# ECCCo:
λ=[0.5,0.1,0.5]
temp=0.5
η=0.01
# Generate counterfactuals using ECCCo generator:
eccco_generator = ECCCoGenerator(
λ=λ,
temp=temp,
opt=Flux.Optimise.Adam(η),
function _plot_eccco_mnist(
x::Union{AbstractArray, Int}=x_factual, target::Int=target;
λ=[0.1,0.1,0.1],
temp=0.1,η=0.01,
plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
opt = Flux.Optimise.Adam(η),
rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
)
ces = Dict()
for (mod_name, mod) in model_dict
ce = generate_counterfactual(
x_factual, target, counterfactual_data, mod, eccco_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
# Setup:
Random.seed!(rng)
if x isa Int
x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1)
end
# Generate counterfactuals using ECCCo generator:
eccco_generator = ECCCoGenerator(
λ=λ,
temp=temp,
opt=opt,
)
ces = Dict()
for (mod_name, mod) in model_dict
ce = generate_counterfactual(
x, target, counterfactual_data, mod, eccco_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
ces[mod_name] = ce
end
_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
)
ces[mod_name] = ce
plts = []
for (_name,ce) in ces
_x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plts = plts[_plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
return plt, eccco_generator
end
_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)
```
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x_factual,28,28)),
axis=nothing,
size=(img_height, img_height),
title="Factual"
```{julia}
plt, eccco_generator = _plot_eccco_mnist()
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```
### All digits
```{julia}
function plot_mnist(
factual::Int,target::Int;
generator::AbstractGenerator,
model::AbstractFittedModel=model_dict["JEM Ensemble"],
data::CounterfactualData=counterfactual_data,
rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
_plot_title::Bool=true,
kwargs...,
)
decision_threshold = !isdefined(kwargs, :decision_threshold) ? 0.9 : decision_threshold
max_iter = !isdefined(kwargs, :max_iter) ? 100 : max_iter
initialization = !isdefined(kwargs, :initialization) ? :identity : initialization
converge_when = !isdefined(kwargs, :converge_when) ? :generator_conditions : converge_when
x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1)
ce = generate_counterfactual(
x, target, data, model, generator;
decision_threshold=decision_threshold, max_iter=max_iter,
initialization=initialization,
converge_when=converge_when,
kwargs...
)
_title = _plot_title ? "$(factual) -> $(target)" : ""
plts = []
for (_name,ce) in ces
_x = CounterfactualExplanations.counterfactual(ce)
_phat = target_probs(ce)
_title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
plt = Plots.plot(
convert2image(MNIST, reshape(_x,28,28)),
axis=nothing,
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
return plt
end
```
```{julia}
if _regen
function plot_all_digits(rng=1;verbose=true,kwargs...)
plts = []
for i in 0:9
for j in 0:9
@info "Generating counterfactual for $(i) -> $(j)"
plt = plot_mnist(i,j;kwargs...,rng=rng)
!verbose || display(plt)
plts = [plts..., plt]
end
end
plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10))
return plt
end
plt = plot_all_digits(generator=eccco_generator)
savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png"))
end
plts = plts[_plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```
### Different Generators
......@@ -432,21 +504,12 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco_benchmark.png"))
## Benchmark
```{julia}
# Benchmark generators:
generators = Dict(
:wachter => GenericGenerator(opt=opt, λ=l2_λ),
:revise => REVISEGenerator(opt=opt, λ=l2_λ),
:greedy => GreedyGenerator(),
)
# Conformal Models:
# Measures:
measures = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.validity,
CounterfactualExplanations.redudancy,
]
```
\ No newline at end of file
No preview for this file type
......@@ -28,7 +28,7 @@ function ECCCoGenerator(;
function _set_size_penalty(ce::AbstractCounterfactualExplanation)
return ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
end
_penalties = [Objectives.distance_l2, _set_size_penalty, ECCCo.distance_from_energy]
_penalties = [Objectives.distance_l1, _set_size_penalty, ECCCo.distance_from_energy]
λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ
return Generator(; penalty=_penalties, λ=λ, opt=opt, kwargs...)
end
......
using ChainRules: ignore_derivatives
using Distances
using Flux
using LinearAlgebra: norm
using Statistics: mean
......@@ -37,7 +38,7 @@ end
function distance_from_energy(
ce::AbstractCounterfactualExplanation;
n::Int=10, niter=250, from_buffer=true, agg=mean, kwargs...
n::Int=10, niter=100, from_buffer=true, agg=mean, kwargs...
)
conditional_samples = []
ignore_derivatives() do
......@@ -46,7 +47,7 @@ function distance_from_energy(
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
end
sampler = _dict[:energy_sampler]
push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer))
end
x′ = CounterfactualExplanations.counterfactual(ce)
loss = map(eachslice(x′, dims=ndims(x′))) do x
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment