diff --git a/Manifest.toml b/Manifest.toml
index 06b80abd0a09ec27e3d9f94fbf4dc29493113df6..9434115488ec05f1ed00e2d6098137a8214bf046 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -785,7 +785,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index 567a5edf16ff6947804306d6569a8706f2fd81f7..4c077ed9a0fb768eb3181835b8c94abd02cd8374 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.8.5"
 manifest_format = "2.0"
-project_hash = "b4b125f21013c1ac841e2bb761cd8922630f9f03"
+project_hash = "53181f00a1b7318c04c79643b4b4b157e79df8f9"
 
 [[deps.AbstractFFTs]]
 deps = ["ChainRulesCore", "LinearAlgebra"]
diff --git a/notebooks/Project.toml b/notebooks/Project.toml
index 9d63873a5c251704aed1ab60c1256b1f9c4c8f71..180fcaf4d57cabfdde9e9f037538959c2d3262aa 100644
--- a/notebooks/Project.toml
+++ b/notebooks/Project.toml
@@ -1,6 +1,5 @@
 [deps]
 AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
-ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
 CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
 CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
@@ -8,7 +7,9 @@ Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc"
 ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
 CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
 JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 0939ebd0ade4958b1a52644ae795d66309a36c21..b4fb8b774215fbe9c035a9d9778bad5ae7ded2a4 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -32,7 +32,7 @@ First, let's create a couple of image classifier architectures:
 # Model parameters:
 epochs = 100
 batch_size = minimum([Int(round(n_obs/10)), 128])
-n_hidden = 50
+n_hidden = 128
 activation = Flux.swish
 builder = MLJFlux.@builder Flux.Chain(
 
@@ -71,7 +71,7 @@ mlp = NeuralNetworkClassifier(
 sampler = ConditionalSampler(
     𝒟x, 𝒟y, 
     input_size=(input_dim,), 
-    batch_size=10
+    batch_size=1
 )
 jem = JointEnergyClassifier(
     sampler;
@@ -131,24 +131,15 @@ println("F1 score (test): $(round(f1,digits=3))")
 ```
 
 ```{julia}
-𝒟x = Uniform(-1,1)
-𝒟y = Categorical(ones(K) ./ K)
-sampler = UnconditionalSampler(𝒟x; input_size=(D,))
-conditional_sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=(D,))
-opt = ImproperSGLD()
-n_iter = 100
-```
-
-```{julia}
-# Random.seed!(1234)
+Random.seed!(1234)
 
 # Set up search:
-factual_label = 8
+factual_label = 9
 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
-target = 3
+target = 4
 factual = predict_label(M, counterfactual_data, x)[1]
 γ = 0.5
-T = 100
+T = 250
 
 η=1.0
 
@@ -168,8 +159,8 @@ ce_jsma = generate_counterfactual(
 )
 
 # ECCCo:
-λ=[0.0,1.0,10.0]
-temp=0.01
+λ=[0.1,0.1,0.1]
+temp=0.1
 
 # Generate counterfactual using ECCCo generator:
 generator = ECCCoGenerator(
diff --git a/src/penalties.jl b/src/penalties.jl
index e9af36f125a7ddc0df9005d7b35bd70e9eb9d4c9..ccd1c8a911e1136f82422b8ee4ca200b58b47b8b 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -1,4 +1,5 @@
 using ChainRules: ignore_derivatives
+using Distances
 using LinearAlgebra: norm
 using Statistics: mean
 
@@ -50,7 +51,8 @@ function distance_from_energy(
     x′ = CounterfactualExplanations.counterfactual(ce)
     loss = map(eachslice(x′, dims=ndims(x′))) do x
         Δ = map(eachcol(conditional_samples[1])) do xsample
-            norm(x - xsample)
+            # 1 .- (x'xsample)/(norm(x)*norm(xsample))
+            norm(x - xsample, 1)
         end
         return mean(Δ)
     end
diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png
index 8c8f3a1869c09a98aafb4cfdb94a5b4f09bf041c..edd2100307beb05f2b55511561bad5a4df0ad9f2 100644
Binary files a/www/eccco_mnist.png and b/www/eccco_mnist.png differ