From a0664452e0f56e398b68ead7b7c981912093b92d Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Thu, 10 Aug 2023 15:30:25 +0200
Subject: [PATCH] uh

---
 notebooks/mnist.qmd | 31 ++++++++++++++++++++++++++++++-
 src/penalties.jl    | 10 +++++-----
 2 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index c63bece4..91b18b15 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -473,6 +473,12 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
 
 #### Energy Delta (not in paper)
 
+```{julia}
+plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1,0.1,3.0], use_energy_delta=true)
+display(plt)
+savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
+```
+
 ```{julia}
 λ_delta = [0.1,0.1,2.5]
 λ = [0.1,0.25,0.25]
@@ -802,6 +808,29 @@ if _regen_all_digits
 end
 ```
 
+#### Energy Delta (not in paper)
+
+```{julia}
+_regen_all_digits = true
+if _regen_all_digits
+    function plot_all_digits(rng=123;verbose=true,img_height=180,kwargs...)
+        plts = []
+        for i in 0:9
+            for j in 0:9
+                @info "Generating counterfactual for $(i) -> $(j)"
+                plt = plot_mnist(i,j;kwargs...,rng=rng, img_height=img_height)
+                !verbose || display(plt)
+                plts = [plts..., plt]
+            end
+        end
+        plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10), dpi=300)
+        return plt
+    end
+    plt = plot_all_digits(generator=gen_delta)
+    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
+end
+```
+
 ## Benchmark
 
 ```{julia}
@@ -882,4 +911,4 @@ plt = draw(
 )   
 display(plt)
 save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
-```
\ No newline at end of file
+```
diff --git a/src/penalties.jl b/src/penalties.jl
index 7ea7225f..cba61638 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -72,13 +72,13 @@ function energy_delta(
         end
     end
 
-    xtarget = conditional_samples[1]                    # conditional samples
-    x = CounterfactualExplanations.decode_state(ce)     # current state
+    xgenerated = conditional_samples[1]                         # conditional samples
+    xproposed = CounterfactualExplanations.decode_state(ce)     # current state
     t = get_target_index(ce.data.y_levels, ce.target)
-    E(x) = -logits(ce.M, x)[t,:]                # negative logits for target class
-    _loss = E(x) .- E(xtarget)
+    E(x) = -logits(ce.M, x)[t,:]                                # negative logits for target class
+    _loss = E(xproposed) .- E(xgenerated)
 
-    _loss = reduce((x, y) -> x + y, _loss) / n          # aggregate over samples
+    _loss = reduce((x, y) -> x + y, _loss) / n                  # aggregate over samples
 
     if return_conditionals
         return conditional_samples[1]
-- 
GitLab