diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index c63bece4938fb69eac7e9a1dc898469692b76376..91b18b157d88c906f57a626f42b6d10e177cf814 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -473,6 +473,12 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
 
 #### Energy Delta (not in paper)
 
+```{julia}
+plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1,0.1,3.0], use_energy_delta=true)
+display(plt)
+savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
+```
+
 ```{julia}
 λ_delta = [0.1,0.1,2.5]
 λ = [0.1,0.25,0.25]
@@ -802,6 +808,29 @@ if _regen_all_digits
 end
 ```
 
+#### Energy Delta (not in paper)
+
+```{julia}
+_regen_all_digits = true
+if _regen_all_digits
+    function plot_all_digits(rng=123; verbose=true, img_height=180, kwargs...)
+        plts = []
+        for i in 0:9
+            for j in 0:9
+                @info "Generating counterfactual for $(i) -> $(j)"
+                plt = plot_mnist(i, j; kwargs..., rng=rng, img_height=img_height)
+                !verbose || display(plt)
+                plts = [plts..., plt]
+            end
+        end
+        plt = Plots.plot(plts...; size=(img_height*10, img_height*10), layout=(10,10), dpi=300)
+        return plt
+    end
+    plt = plot_all_digits(generator=gen_delta)
+    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
+end
+```
+
 ## Benchmark
 
 ```{julia}
@@ -882,4 +911,4 @@ plt = draw(
 )
 display(plt)
 save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
-```
\ No newline at end of file
+```
diff --git a/src/penalties.jl b/src/penalties.jl
index 7ea7225f20d9056fcdb9287ff3a5f7157e51e0f3..cba61638c58e6007ce3ca35590fcf1eff5699c03 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -72,13 +72,13 @@ function energy_delta(
         end
     end
 
-    xtarget = conditional_samples[1] # conditional samples
-    x = CounterfactualExplanations.decode_state(ce) # current state
+    xgenerated = conditional_samples[1] # conditional samples
+    xproposed = CounterfactualExplanations.decode_state(ce) # current state
     t = get_target_index(ce.data.y_levels, ce.target)
-    E(x) = -logits(ce.M, x)[t,:] # negative logits for target class
-    _loss = E(x) .- E(xtarget)
+    E(x) = -logits(ce.M, x)[t,:] # negative logits for target class
+    _loss = E(xproposed) .- E(xgenerated)
 
-    _loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples
+    _loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples
 
     if return_conditionals
         return conditional_samples[1]