Skip to content
Snippets Groups Projects
Commit dddf8988 authored by Pat Alt's avatar Pat Alt
Browse files

mnist benchmark set up

parent 39721697
No related branches found
No related tags found
No related merge requests found
artifacts/results/images/mnist_eccco.png

22.2 KiB | W: | H:

artifacts/results/images/mnist_eccco.png

24.2 KiB | W: | H:

artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
  • 2-up
  • Swipe
  • Onion skin
No preview for this file type
No preview for this file type
......@@ -148,9 +148,10 @@ _retrain = false
_regen = false
# Data:
n_obs = 1000
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
counterfactual_data.generative_model = vae
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
......@@ -186,7 +187,7 @@ sampler = ConditionalSampler(
input_size=(input_dim,),
batch_size=10,
)
α = [1.0,1.0,1e-1] # penalty strengths
α = [1.0,1.0,1e-2] # penalty strengths
```
```{julia}
......@@ -258,6 +259,7 @@ end
```{julia}
# Plot generated samples:
n_regen = 150
if _regen
for (mod_name, mod) in model_dict
if ECCCo._has_sampler(mod)
......@@ -272,12 +274,11 @@ if _regen
opt = ImproperSGLD()
f(x) = logits(mod, x)
n_iter = 200
_w = 1500
plts = []
neach = 10
for i in 1:10
x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
......@@ -321,7 +322,7 @@ model_performance
```{julia}
function _plot_eccco_mnist(
x::Union{AbstractArray, Int}=x_factual, target::Int=target;
λ=[0.1,0.1,0.1],
λ=[0.5,0.1,0.5],
temp=0.1,η=0.01,
plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
opt = Flux.Optimise.Adam(η),
......@@ -509,7 +510,29 @@ measures = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.validity,
CounterfactualExplanations.redudancy,
CounterfactualExplanations.Evaluation.validity,
CounterfactualExplanations.Evaluation.redundancy,
]
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="MNIST",
n_individuals=100,
initialization=:identity,
)
Serialization.serialize(joinpath(output_path, "mnist_benchmark.jls"), bmk)
```
```{julia}
@chain bmk() begin
@group_by(dataname, generator, model, variable)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable == "distance_from_targets")
end
```
\ No newline at end of file
No preview for this file type
......@@ -38,7 +38,7 @@ end
function distance_from_energy(
ce::AbstractCounterfactualExplanation;
n::Int=10, niter=100, from_buffer=true, agg=mean, kwargs...
n::Int=10, niter=60, from_buffer=true, agg=mean, kwargs...
)
conditional_samples = []
ignore_derivatives() do
......@@ -47,7 +47,7 @@ function distance_from_energy(
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
end
sampler = _dict[:energy_sampler]
push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer))
push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
end
x′ = CounterfactualExplanations.counterfactual(ce)
loss = map(eachslice(x′, dims=ndims(x′))) do x
......@@ -64,16 +64,15 @@ end
function distance_from_targets(
ce::AbstractCounterfactualExplanation;
n::Int=1000, agg=mean
n::Int=100, agg=mean
)
target_idx = ce.data.output_encoder.labels .== ce.target
target_samples = ce.data.X[:,target_idx] |>
X -> X[:,rand(1:end,n)]
x′ = CounterfactualExplanations.counterfactual(ce)
loss = map(eachslice(x′, dims=3)) do x
x = Matrix(x)
loss = map(eachslice(x′, dims=ndims(x′))) do x
Δ = map(eachcol(target_samples)) do xsample
norm(x - xsample)
norm(x - xsample, 1)
end
return mean(Δ)
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment