diff --git a/artifacts/results/images/mnist_eccco.png b/artifacts/results/images/mnist_eccco.png
index 4d91b401d4b88aaeaae8660904fe687f128f74b3..e1a9d1ced77de0967ab201950f7c9a07317c183d 100644
Binary files a/artifacts/results/images/mnist_eccco.png and b/artifacts/results/images/mnist_eccco.png differ
diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index b0e32960a27364c9cf284cd5d8409ec363dd5299..93409644f8fdc0c48c85e12c72288c9d285443f4 100644
Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ
diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index d3db5f46b841781aeaa993c20653bae898dc0aad..95053840467a78ebab6998e996a42fdcca25bc42 100644
Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 82d559fddefc8f283e66c3366ec4ff6fb9e9f48b..cbb323106029ce5358ee74cf3ce99205e6aae973 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -148,9 +148,10 @@ _retrain = false
 _regen = false
 
 # Data:
-n_obs = 1000
+n_obs = 10000
 counterfactual_data = load_mnist(n_obs)
 counterfactual_data.X = pre_process.(counterfactual_data.X)
+counterfactual_data.generative_model = vae
 X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
 X = table(permutedims(X))
 x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
@@ -186,7 +187,7 @@ sampler = ConditionalSampler(
     input_size=(input_dim,), 
     batch_size=10,
 )
-α = [1.0,1.0,1e-1]      # penalty strengths
+α = [1.0,1.0,1e-2]      # penalty strengths
 ```
 
 ```{julia}
@@ -258,6 +259,7 @@ end
 
 ```{julia}
 # Plot generated samples:
+n_regen = 150
 if _regen 
     for (mod_name, mod) in model_dict
         if ECCCo._has_sampler(mod)
@@ -272,12 +274,11 @@ if _regen
         opt = ImproperSGLD()
         f(x) = logits(mod, x)
 
-        n_iter = 200
         _w = 1500
         plts = []
         neach = 10
         for i in 1:10
-            x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
+            x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i)
             plts_i = []
             for j in 1:size(x, 2)
                 xj = x[:,j]
@@ -321,7 +322,7 @@ model_performance
 ```{julia}
 function _plot_eccco_mnist(
     x::Union{AbstractArray, Int}=x_factual, target::Int=target;
-    λ=[0.1,0.1,0.1],
+    λ=[0.5,0.1,0.5],
     temp=0.1,η=0.01,
     plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
     opt = Flux.Optimise.Adam(η),
@@ -509,7 +510,29 @@ measures = [
     CounterfactualExplanations.distance,
     ECCCo.distance_from_energy,
     ECCCo.distance_from_targets,
-    CounterfactualExplanations.validity,
-    CounterfactualExplanations.redudancy,
+    CounterfactualExplanations.Evaluation.validity,
+    CounterfactualExplanations.Evaluation.redundancy,
 ]
+
+bmk = benchmark(
+    counterfactual_data; 
+    models=model_dict, 
+    generators=generator_dict, 
+    measure=measures,
+    suppress_training=true, dataname="MNIST",
+    n_individuals=100,
+    initialization=:identity,
+)
+
+Serialization.serialize(joinpath(output_path, "mnist_benchmark.jls"), bmk)
+```
+
+
+```{julia}
+@chain bmk() begin
+    @group_by(dataname, generator, model, variable)
+    @summarize(mean=mean(value),sd=std(value))
+    @ungroup
+    @filter(variable == "distance_from_targets")
+end
 ```
\ No newline at end of file
diff --git a/paper/paper.pdf b/paper/paper.pdf
index b7ffb3b71de08967964800dacafa97d294ab1dfa..661c91d583d80ffd8c84daf9afdd5d5f7c4f5ede 100644
Binary files a/paper/paper.pdf and b/paper/paper.pdf differ
diff --git a/src/penalties.jl b/src/penalties.jl
index 92b007d06e219bf2bb3b69952d286619a0c5f8a7..820b2f3ba6874c61d4bb87badac676037f562b06 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -38,7 +38,7 @@ end
 
 function distance_from_energy(
     ce::AbstractCounterfactualExplanation;
-    n::Int=10, niter=100, from_buffer=true, agg=mean, kwargs...
+    n::Int=10, niter=60, from_buffer=true, agg=mean, kwargs...
 )
     conditional_samples = []
     ignore_derivatives() do
@@ -47,7 +47,7 @@ function distance_from_energy(
             _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
         end
         sampler = _dict[:energy_sampler]
-        push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer))
+        push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
     end
     x′ = CounterfactualExplanations.counterfactual(ce)
     loss = map(eachslice(x′, dims=ndims(x′))) do x
@@ -64,16 +64,15 @@ end
 
 function distance_from_targets(
     ce::AbstractCounterfactualExplanation;
-    n::Int=1000, agg=mean
+    n::Int=100, agg=mean
 )
     target_idx = ce.data.output_encoder.labels .== ce.target
     target_samples = ce.data.X[:,target_idx] |>
         X -> X[:,rand(1:end,n)]
     x′ = CounterfactualExplanations.counterfactual(ce)
-    loss = map(eachslice(x′, dims=3)) do x
-        x = Matrix(x)
+    loss = map(eachslice(x′, dims=ndims(x′))) do x
         Δ = map(eachcol(target_samples)) do xsample
-            norm(x - xsample)
+            norm(x - xsample, 1)
         end
         return mean(Δ)
     end