diff --git a/artifacts/results/mnist_architectures.jls b/artifacts/results/mnist_architectures.jls
index 4054e47c244cf42615793282663b50e07abf4755..c227bf2cec8ce4c773d4a25f0e73440f03486b4c 100644
Binary files a/artifacts/results/mnist_architectures.jls and b/artifacts/results/mnist_architectures.jls differ
diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index 1a74f76e4958d83d3fc08c1fe1c247292e07f297..3537f90018d2104bc168f3b6d848f7030487fcaa 100644
Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ
diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index 2100c22bdb104119ce7eb024a9d24da877da7105..96a8d445d9ec1ad815bf56539ae121d82f7e742c 100644
Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ
diff --git a/notebooks/.CondaPkg/env/conda-meta/history b/notebooks/.CondaPkg/env/conda-meta/history
index dd869616bef06f2eeffd268caa0a34b25e3af339..6c7683b9e57e48bee3f8f60a08c00c9a8f1c2d8b 100644
--- a/notebooks/.CondaPkg/env/conda-meta/history
+++ b/notebooks/.CondaPkg/env/conda-meta/history
@@ -1,4 +1,4 @@
-==> 2023-08-08 15:03:23 <==
+==> 2023-08-10 11:22:12 <==
 # cmd: /Users/FA31DU/.julia/artifacts/6ecf04294c7f327e02e84972f34835649a5eb35e/bin/micromamba -r /Users/FA31DU/.julia/scratchspaces/0b3b1443-0f03-428d-bdfb-f27f9c1191ea/root create -y -p /Users/FA31DU/code/ECCCo.jl/notebooks/.CondaPkg/env --override-channels --no-channel-priority numpy[version='*'] pip[version='>=22.0.0'] python[version='>=3.7,<4',channel='conda-forge',build='*cpython*'] -c conda-forge
 # conda version: 3.8.0
 +https://conda.anaconda.org/conda-forge/osx-64::xz-5.2.6-h775f41a_0
diff --git a/notebooks/.CondaPkg/meta b/notebooks/.CondaPkg/meta
index 1b11b47e6fbbf80a0ee54e97be26195177f40138..6b7ba5c59ae065f773611edc2fc41659836fb772 100644
Binary files a/notebooks/.CondaPkg/meta and b/notebooks/.CondaPkg/meta differ
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 2d88f306a27ef313f91998472fa8b4fa17461300..91b18b157d88c906f57a626f42b6d10e177cf814 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -164,8 +164,8 @@ end
 
 ```{julia}
 # Hyper:
-_retrain = true
-_regen = true
+_retrain = false
+_regen = false
 
 # Data:
 n_obs = 10000
@@ -357,6 +357,17 @@ model_performance
 ### Different Models
 
 ```{julia}
+function plot_mnist(ce; size=(img_height, img_height), kwrgs...)
+    # Plot the counterfactual in input space as a 28x28 MNIST image:
+    x = CounterfactualExplanations.counterfactual(ce)
+    plt = Plots.plot(
+        convert2image(MNIST, reshape(x, 28, 28));
+        axis=([], false), 
+        size=size,
+        kwrgs...,
+    )
+    return plt
+end
+
 function _plot_eccco_mnist(
     x::Union{AbstractArray, Int}=x_factual, target::Int=target;
     λ=[0.1,0.25,0.25],
@@ -372,6 +383,7 @@ function _plot_eccco_mnist(
     plot_factual::Bool = false,
     generator::Union{Nothing,CounterfactualExplanations.AbstractGenerator}=nothing,
     test_data::Bool = false,
+    use_energy_delta::Bool = false,
     kwrgs...,
 )
 
@@ -392,6 +404,7 @@ function _plot_eccco_mnist(
             use_class_loss=use_class_loss,
             nsamples=10,
             nmin=10,
+            use_energy_delta=use_energy_delta,
         )
     end
 
@@ -458,6 +471,27 @@ display(plt)
 savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
 ```
 
+#### Energy Delta (not in paper)
+
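+The energy-delta variant replaces the distance-from-energy penalty with the gap between the model energy of the counterfactual and that of samples drawn from the learned conditional distribution, where energy is the negative logit for the target class (see `energy_delta` in `src/penalties.jl`): roughly $\mathcal{E}_t(x^\prime) - \frac{1}{n}\sum_{i=1}^n \mathcal{E}_t(\hat{x}_i)$ with $\mathcal{E}_t(x) = -f_\theta(x)[t]$.
+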
+```{julia}
+plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1,0.1,3.0], use_energy_delta=true)
+display(plt)
+savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
+```
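+
+The `gen_delta` generator returned here is reused further below to generate the full 10×10 grid of digit-to-digit counterfactuals.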
+
+```{julia}
+λ_delta = [0.1,0.1,2.5]     # penalty strengths (distance, set size, energy delta)
+λ = [0.1,0.25,0.25]         # penalty strengths (distance, set size, energy)
+plts = []
+for i in 0:9
+    plt, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ, plot_title="Distance")
+    plt_delta, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ_delta, use_energy_delta=true, plot_title="Energy Delta")
+    plt = Plots.plot(plt, plt_delta; size=(img_height*2,img_height), layout=(1,2))
+    display(plt)
+    push!(plts, plt)
+end
+```
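+
+A possible way to assemble and save the per-digit comparisons collected above (a sketch; the grid layout and output filename are assumptions):
+
+```{julia}
+plt_grid = Plots.plot(plts...; size=(img_height * 4, img_height * 5), layout=(5, 2))
+savefig(plt_grid, joinpath(output_images_path, "mnist_eccco_digits_delta_vs_distance.png"))
+```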
+
 #### Additional Models (not in paper)
 
 LeNet-5:
@@ -774,6 +808,29 @@ if _regen_all_digits
 end
 ```
 
+#### Energy Delta (not in paper)
+
+```{julia}
+_regen_all_digits = true
+if _regen_all_digits
+    function plot_all_digits(rng=123; verbose=true, img_height=180, kwargs...)
+        plts = []
+        for i in 0:9
+            for j in 0:9
+                @info "Generating counterfactual for $(i) -> $(j)"
+                plt = plot_mnist(i, j; kwargs..., rng=rng, img_height=img_height)
+                !verbose || display(plt)
+                push!(plts, plt)
+            end
+        end
+        plt = Plots.plot(plts...; size=(img_height*10, img_height*10), layout=(10, 10), dpi=300)
+        return plt
+    end
+    plt = plot_all_digits(generator=gen_delta)
+    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
+end
+```
+
 ## Benchmark
 
 ```{julia}
@@ -854,4 +911,4 @@ plt = draw(
 )   
 display(plt)
 save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
-```
\ No newline at end of file
+```
diff --git a/notebooks/prototyping.qmd b/notebooks/prototyping.qmd
new file mode 100644
index 0000000000000000000000000000000000000000..7b61b6f8d8f7385e8c60ef52d1a7fb9d5dfb736f
--- /dev/null
+++ b/notebooks/prototyping.qmd
@@ -0,0 +1,130 @@
+```{julia}
+include("$(pwd())/notebooks/setup.jl")
+eval(setup_notebooks)
+```
+
+# Linearly Separable Data
+
+```{julia}
+# Hyper:
+_retrain = false
+
+# Data:
+test_size = 0.2
+n_obs = Int(1000 / (1.0 - test_size))
+counterfactual_data, test_data = train_test_split(
+    load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.)); 
+    test_size=test_size
+)
+X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
+X = table(permutedims(X))
+labels = counterfactual_data.output_encoder.labels
+input_dim, n_obs = size(counterfactual_data.X)
+output_dim = length(unique(labels))
+```
+
+First, let's set up the classifier architecture and training parameters:
+
+```{julia}
+# Model parameters:
+epochs = 100
+batch_size = minimum([Int(round(n_obs/10)), 128])
+n_hidden = 16
+activation = Flux.swish
+builder = MLJFlux.MLP(
+    hidden=(n_hidden, n_hidden, n_hidden), 
+    σ=activation
+)
+n_ens = 5                               # number of models in ensemble
+_loss = Flux.Losses.crossentropy        # loss function
+_finaliser = Flux.softmax               # finaliser function
+```
+
+```{julia}
+# JEM parameters:
+𝒟x = Normal()
+𝒟y = Categorical(ones(output_dim) ./ output_dim)
+sampler = ConditionalSampler(
+    𝒟x, 𝒟y, 
+    input_size=(input_dim,), 
+    batch_size=50,
+)
+α = [1.0,1.0,1e-1]      # penalty strengths
+```
+
+
+```{julia}
+# Joint Energy Model:
+model = JointEnergyClassifier(
+    sampler;
+    builder=builder,
+    epochs=epochs,
+    batch_size=batch_size,
+    finaliser=_finaliser,
+    loss=_loss,
+    jem_training_params=(
+        α=α,verbosity=10,
+    ),
+    sampling_steps=30,
+)
+```
+
+```{julia}
+conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
+mach = machine(conf_model, X, labels)
+@info "Begin training model."
+fit!(mach)
+@info "Finished training model."
+M = ECCCo.ConformalModel(mach.model, mach.fitresult)
+```
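+
+A quick sanity check before generating counterfactuals (a sketch; assumes `predict_label` returns one predicted label per observation, in the same order as `labels`):
+
+```{julia}
+yhat = predict_label(M, counterfactual_data)
+acc = sum(vec(yhat) .== vec(labels)) / length(labels)
+@info "Training accuracy: $(round(acc, digits=3))"
+```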
+
+```{julia}
+λ₁ = 0.25
+λ₂ = 0.75
+λ₃ = 0.75
+Λ = [λ₁, λ₂, λ₃]
+
+opt = Flux.Optimise.Descent(0.01)
+use_class_loss = false
+
+# Benchmark generators: both use the same λ and optimiser and differ only in the energy
+# penalty (`use_energy_delta=true` swaps `distance_from_energy` for `energy_delta`, see `src/generator.jl`):
+generator_dict = Dict(
+    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
+    "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true),
+)
+```
+
+```{julia}
+Random.seed!(2023)
+
+X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
+factual_label = levels(labels)[1]
+x_factual = reshape(X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))], input_dim, 1)
+target = levels(labels)[2]
+factual = predict_label(M, counterfactual_data, x_factual)[1]
+
+ces = Dict{Any,Any}()
+plts = []
+for (name, generator) in generator_dict
+    ce = generate_counterfactual(
+        x_factual, target, counterfactual_data, M, generator;
+        initialization=:identity, 
+        converge_when=:generator_conditions,
+    )
+    plt = Plots.plot(
+        ce, title=name, alpha=0.2, 
+        cbar=false, 
+    )
+    if contains(name, "ECCCo")
+        _X = distance_from_energy(ce, return_conditionals=true)
+        Plots.scatter!(
+            _X[1,:],_X[2,:], color=:purple, shape=:star5, 
+            ms=10, label="x̂|$target", alpha=0.5
+        )
+    end
+    push!(plts, plt)
+    ces[name] = ce
+end
+plt = Plots.plot(plts..., size=(800,350))
+display(plt)
+```
\ No newline at end of file
diff --git a/src/generator.jl b/src/generator.jl
index c356fb1aa9c956e30609e943cb99074b785e5136..2119b54407a8bf749a8ea645cc06411860572dda 100644
--- a/src/generator.jl
+++ b/src/generator.jl
@@ -25,6 +25,7 @@ function ECCCoGenerator(;
     use_class_loss::Bool=false,
     nsamples::Int=50,
     nmin::Int=25,
+    use_energy_delta::Bool=false,       # if true, use the energy-delta penalty instead of distance-from-energy
     kwargs...
 )
 
@@ -47,7 +48,11 @@ function ECCCoGenerator(;
 
     # Energy penalty
     function _energy_penalty(ce::AbstractCounterfactualExplanation)
-        return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        if use_energy_delta
+            return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...)
+        else
+            return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        end
     end
 
     _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
diff --git a/src/penalties.jl b/src/penalties.jl
index d8d1b7f669c718ece9cf504b6ffa672787cfc5fe..cba61638c58e6007ce3ca35590fcf1eff5699c03 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -1,4 +1,5 @@
 using ChainRules: ignore_derivatives
+using CounterfactualExplanations: get_target_index
 using Distances
 using Flux
 using LinearAlgebra: norm
@@ -38,6 +39,54 @@ function set_size_penalty(
 
 end
 
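+"""
+    energy_delta(ce::AbstractCounterfactualExplanation; n=50, kwargs...)
+
+Energy-delta penalty: the difference between the energy of the current counterfactual
+state and the (average) energy of samples generated from the model's conditional
+distribution, where energy is the negative logit for the target class. Samples are
+drawn through the `EnergySampler` cached in `ce.params[:energy_sampler]` (lowest-energy
+buffer samples by default, random buffer samples if `choose_random=true`). If
+`return_conditionals=true`, the conditional samples themselves are returned.
+"""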
+function energy_delta(
+    ce::AbstractCounterfactualExplanation;
+    n::Int=50, niter=500, from_buffer=true, agg=mean,
+    choose_lowest_energy=true,
+    choose_random=false,
+    nmin::Int=25,
+    return_conditionals=false,
+    kwargs...
+)
+
+    _loss = 0.0
+    nmin = minimum([nmin, n])
+
+    @assert (choose_lowest_energy ⊻ choose_random) || (!choose_lowest_energy && !choose_random) "Must choose either lowest energy or random samples or neither."
+
+    conditional_samples = []
+    ignore_derivatives() do
+        _dict = ce.params
+        if !(:energy_sampler ∈ collect(keys(_dict)))
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
+        end
+        eng_sampler = _dict[:energy_sampler]
+        if choose_lowest_energy
+            nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
+            xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
+            push!(conditional_samples, xmin)
+        elseif choose_random
+            push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
+        else
+            push!(conditional_samples, eng_sampler.buffer)
+        end
+    end
+
+    xgenerated = conditional_samples[1]                         # conditional samples
+    xproposed = CounterfactualExplanations.decode_state(ce)     # current state
+    t = get_target_index(ce.data.y_levels, ce.target)
+    E(x) = -logits(ce.M, x)[t,:]                                # negative logits for target class
+    _loss = E(xproposed) .- E(xgenerated)                       # energy gap per generated sample
+
+    _loss = sum(_loss) / n                                      # aggregate over samples
+
+    if return_conditionals
+        return conditional_samples[1]
+    end
+    return _loss
+
+end
+
 function distance_from_energy(
     ce::AbstractCounterfactualExplanation;
     n::Int=50, niter=500, from_buffer=true, agg=mean,