diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd
index def8c2757a419438923f3367cd976c2ffaee6ca4..155bc957de66635ff7bbf61274e09d9eab79490b 100644
--- a/notebooks/fidelity.qmd
+++ b/notebooks/fidelity.qmd
@@ -15,10 +15,38 @@ using Plots
 
 # Fidelity Measures
 
+## Binary
+
+```{julia}
+# Setup
+counterfactual_data = load_linearly_separable()
+M = fit_model(counterfactual_data, :DeepEnsemble)
+target = 2
+factual = 1
+chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
+x = select_factual(counterfactual_data, chosen)
+
+# Search:
+generator = GenericGenerator(opt=Descent(0.01))
+ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
+```
+
+```{julia}
+niter = 100
+nsamples = 100
+
+sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
+Xgen = rand(sampler, nsamples)
+plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
+scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
+```
+
+## Multi-Class
+
 ```{julia}
 # Setup
 counterfactual_data = load_multi_class()
-M = fit_model(counterfactual_data, :MLP)
+M = fit_model(counterfactual_data, :DeepEnsemble)
 target = 4
 factual = 2
 chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
@@ -45,7 +73,19 @@ p1 = plot(ce)
 
 
 ```{julia}
-@objective(generator, _ + 0.1distance_l2 + 100.0distance_from_energy)
+using CCE: distance_from_energy
+@objective(generator, _ + 0.1distance_l2 + 10.0distance_from_energy)
 ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
 p2 = plot(ce)
+```
+
+
+```{julia}
+using CCE: distance_from_targets
+@objective(
+    generator, 
+    _ + 0.1distance_l2 + 1.0distance_from_energy + 10.0distance_from_targets
+)
+ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
+p3 = plot(ce)
 ```
\ No newline at end of file
diff --git a/src/penalties.jl b/src/penalties.jl
index a3f35a9b3291f0919584e32ff01805ef22ca009d..5f85af406725ad258610b65a54ce8ecd371f13c2 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -59,4 +59,25 @@ function distance_from_energy(
 
     return loss
 
-end
\ No newline at end of file
+end
+
+function distance_from_targets(
+    counterfactual_explanation::AbstractCounterfactualExplanation;
+    n::Int=100, agg=mean
+)
+    target_samples = counterfactual_explanation.data.X |>
+        X -> X[:,rand(1:end,n)]
+    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
+    loss = map(eachslice(x′, dims=3)) do x
+        x = Matrix(x)
+        Δ = map(eachcol(target_samples)) do xsample
+            norm(x - xsample)
+        end
+        return mean(Δ)
+    end
+    loss = agg(loss)
+
+    return loss
+
+end
+
diff --git a/src/sampling.jl b/src/sampling.jl
index 2cf346ca70ac22e307080bf1f06c0e44738e8396..29ac4c9aedb8bcb409fd7f543cf9ceaf1516c394 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -1,8 +1,9 @@
 using CounterfactualExplanations
 using Distributions
+using Flux
 using JointEnergyModels
 
-(model::AbstractFittedModel)(x) = logits(model, x)
+(model::AbstractFittedModel)(x) = log.(CounterfactualExplanations.predict_proba(model, nothing, x))
 
 mutable struct EnergySampler
     ce::CounterfactualExplanation
@@ -28,7 +29,7 @@ function EnergySampler(
 
     # Fit:
     i = get_target_index(data.y_levels, ce.target)
-    buffer = sampler(model.model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
+    buffer = sampler(model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
 
     return EnergySampler(ce, sampler, opt, buffer)
 end