diff --git a/artifacts/results/images/mnist_eccco.png b/artifacts/results/images/mnist_eccco.png index 4d91b401d4b88aaeaae8660904fe687f128f74b3..e1a9d1ced77de0967ab201950f7c9a07317c183d 100644 Binary files a/artifacts/results/images/mnist_eccco.png and b/artifacts/results/images/mnist_eccco.png differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index b0e32960a27364c9cf284cd5d8409ec363dd5299..93409644f8fdc0c48c85e12c72288c9d285443f4 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index d3db5f46b841781aeaa993c20653bae898dc0aad..95053840467a78ebab6998e996a42fdcca25bc42 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 82d559fddefc8f283e66c3366ec4ff6fb9e9f48b..cbb323106029ce5358ee74cf3ce99205e6aae973 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -148,9 +148,10 @@ _retrain = false _regen = false # Data: -n_obs = 1000 +n_obs = 10000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) +counterfactual_data.generative_model = vae X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1) @@ -186,7 +187,7 @@ sampler = ConditionalSampler( input_size=(input_dim,), batch_size=10, ) -α = [1.0,1.0,1e-1] # penalty strengths +α = [1.0,1.0,1e-2] # penalty strengths ``` ```{julia} @@ -258,6 +259,7 @@ end ```{julia} # Plot generated samples: +n_regen = 150 if _regen for (mod_name, mod) in model_dict if ECCCo._has_sampler(mod) @@ -272,12 +274,11 @@ if _regen opt = ImproperSGLD() f(x) = logits(mod, x) - n_iter = 200 _w = 1500 plts = [] neach = 10 for i in 1:10 - x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) + x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i) plts_i = [] for j in 1:size(x, 2) xj = x[:,j] @@ -321,7 +322,7 @@ model_performance ```{julia} function _plot_eccco_mnist( x::Union{AbstractArray, Int}=x_factual, target::Int=target; - λ=[0.1,0.1,0.1], + λ=[0.5,0.1,0.5], temp=0.1,η=0.01, plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"], opt = Flux.Optimise.Adam(η), @@ -509,7 +510,29 @@ measures = [ CounterfactualExplanations.distance, ECCCo.distance_from_energy, ECCCo.distance_from_targets, - CounterfactualExplanations.validity, - CounterfactualExplanations.redudancy, + CounterfactualExplanations.Evaluation.validity, + CounterfactualExplanations.Evaluation.redundancy, ] + +bmk = benchmark( + counterfactual_data; + models=model_dict, + generators=generator_dict, + measure=measures, + suppress_training=true, dataname="MNIST", + n_individuals=100, + initialization=:identity, +) + +Serialization.serialize(joinpath(output_path, "mnist_benchmark.jls"), bmk) +``` + + +```{julia} +@chain bmk() begin + @group_by(dataname, generator, model, variable) + @summarize(mean=mean(value),sd=std(value)) + @ungroup + @filter(variable == "distance_from_targets") +end ``` \ No newline at end of file diff --git a/paper/paper.pdf b/paper/paper.pdf index b7ffb3b71de08967964800dacafa97d294ab1dfa..661c91d583d80ffd8c84daf9afdd5d5f7c4f5ede 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/src/penalties.jl b/src/penalties.jl index 92b007d06e219bf2bb3b69952d286619a0c5f8a7..820b2f3ba6874c61d4bb87badac676037f562b06 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -38,7 +38,7 @@ end function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=100, from_buffer=true, agg=mean, kwargs... + n::Int=10, niter=60, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do @@ -47,7 +47,7 @@ function distance_from_energy( _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) end sampler = _dict[:energy_sampler] - push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer)) + push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer)) end x′ = CounterfactualExplanations.counterfactual(ce) loss = map(eachslice(x′, dims=ndims(x′))) do x @@ -64,16 +64,15 @@ end function distance_from_targets( ce::AbstractCounterfactualExplanation; - n::Int=1000, agg=mean + n::Int=100, agg=mean ) target_idx = ce.data.output_encoder.labels .== ce.target target_samples = ce.data.X[:,target_idx] |> X -> X[:,rand(1:end,n)] x′ = CounterfactualExplanations.counterfactual(ce) - loss = map(eachslice(x′, dims=3)) do x - x = Matrix(x) + loss = map(eachslice(x′, dims=ndims(x′))) do x Δ = map(eachcol(target_samples)) do xsample - norm(x - xsample) + norm(x - xsample, 1) end return mean(Δ) end