diff --git a/artifacts/results/images/mnist_eccco.png b/artifacts/results/images/mnist_eccco.png index e1a9d1ced77de0967ab201950f7c9a07317c183d..6670bc2859abf0b1203839b395f242bafd870637 100644 Binary files a/artifacts/results/images/mnist_eccco.png and b/artifacts/results/images/mnist_eccco.png differ diff --git a/artifacts/results/images/mnist_eccco_benchmark.png b/artifacts/results/images/mnist_eccco_benchmark.png index 49b590bc15223b0bf69a600c0648a91cc578c048..71f88516662460027dce29a2956e127226dfe2eb 100644 Binary files a/artifacts/results/images/mnist_eccco_benchmark.png and b/artifacts/results/images/mnist_eccco_benchmark.png differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index 93409644f8fdc0c48c85e12c72288c9d285443f4..0d7f33df172d6addc14b3b86d77b6aeef377b66d 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 95053840467a78ebab6998e996a42fdcca25bc42..b297bd7255683243eb19f6ba6d2cd3c9019d689f 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index cbb323106029ce5358ee74cf3ce99205e6aae973..c1fb622c07f3e836e6d66f21dd0cb8a996f0159c 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -520,7 +520,8 @@ bmk = benchmark( generators=generator_dict, measure=measures, suppress_training=true, dataname="MNIST", - n_individuals=100, + n_individuals=5, + factual=0, target=1, initialization=:identity, ) @@ -533,6 +534,6 @@ Serialization.serialize(joinpath(output_path, "mnist_benchmark.jls"), bmk) @group_by(dataname, generator, model, variable) @summarize(mean=mean(value),sd=std(value)) @ungroup - @filter(variable == "distance_from_targets") + @filter(variable == "distance_from_energy") end ``` \ No newline at end of file diff --git a/paper/paper.pdf b/paper/paper.pdf index 661c91d583d80ffd8c84daf9afdd5d5f7c4f5ede..4444c69eb76435de66017be4153aaa16f3728b2b 100644 Binary files a/paper/paper.pdf and b/paper/paper.pdf differ diff --git a/src/penalties.jl b/src/penalties.jl index 820b2f3ba6874c61d4bb87badac676037f562b06..3afe195ff82a60f33934bf1c5a5bd584f592f478 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -38,7 +38,7 @@ end function distance_from_energy( ce::AbstractCounterfactualExplanation; - n::Int=10, niter=60, from_buffer=true, agg=mean, kwargs... + n::Int=10, niter=100, from_buffer=true, agg=mean, kwargs... ) conditional_samples = [] ignore_derivatives() do @@ -64,7 +64,7 @@ end function distance_from_targets( ce::AbstractCounterfactualExplanation; - n::Int=100, agg=mean + n::Int=1000, agg=mean ) target_idx = ce.data.output_encoder.labels .== ce.target target_samples = ce.data.X[:,target_idx] |>