diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index c63bece4938fb69eac7e9a1dc898469692b76376..91b18b157d88c906f57a626f42b6d10e177cf814 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -473,6 +473,12 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
 
 #### Energy Delta (not in paper)
 
+```{julia}
+plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1,0.1,3.0], use_energy_delta=true)
+display(plt)
+savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
+```
+
 ```{julia}
 λ_delta = [0.1,0.1,2.5]
 λ = [0.1,0.25,0.25]
@@ -802,6 +808,29 @@ if _regen_all_digits
 end
 ```
 
+#### Energy Delta (not in paper)
+
+```{julia}
+_regen_all_digits = true
+if _regen_all_digits
+    function plot_all_digits(rng=123;verbose=true,img_height=180,kwargs...)
+        plts = []
+        for i in 0:9
+            for j in 0:9
+                @info "Generating counterfactual for $(i) -> $(j)"
+                plt = plot_mnist(i,j;kwargs...,rng=rng, img_height=img_height)
+                !verbose || display(plt)
+                plts = [plts..., plt]
+            end
+        end
+        plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10), dpi=300)
+        return plt
+    end
+    plt = plot_all_digits(generator=gen_delta)
+    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
+end
+```
+
 ## Benchmark
 
 ```{julia}
@@ -882,4 +911,4 @@ plt = draw(
 )   
 display(plt)
 save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
-```
\ No newline at end of file
+```
diff --git a/src/penalties.jl b/src/penalties.jl
index 7ea7225f20d9056fcdb9287ff3a5f7157e51e0f3..cba61638c58e6007ce3ca35590fcf1eff5699c03 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -72,13 +72,13 @@ function energy_delta(
         end
     end
 
-    xtarget = conditional_samples[1]                    # conditional samples
-    x = CounterfactualExplanations.decode_state(ce)     # current state
+    xgenerated = conditional_samples[1]                         # conditional samples
+    xproposed = CounterfactualExplanations.decode_state(ce)     # current state
     t = get_target_index(ce.data.y_levels, ce.target)
-    E(x) = -logits(ce.M, x)[t,:]                # negative logits for target class
-    _loss = E(x) .- E(xtarget)
+    E(x) = -logits(ce.M, x)[t,:]                                # negative logits for target class
+    _loss = E(xproposed) .- E(xgenerated)
 
-    _loss = reduce((x, y) -> x + y, _loss) / n          # aggregate over samples
+    _loss = reduce((x, y) -> x + y, _loss) / n                  # aggregate over samples
 
     if return_conditionals
         return conditional_samples[1]