@@ -164,8 +164,8 @@ end
 # Hyper:
-_retrain = true
-_regen = true
+_retrain = false
+_regen = false
 # Data:
 n_obs = 10000
@@ -357,6 +357,17 @@ model_performance
 ### Different Models
+function plot_mnist(ce; size=(img_height, img_height), kwrgs...)  # Plot the counterfactual stored in `ce` as an image; the default `size` reads `img_height` from the enclosing scope.
+    x = CounterfactualExplanations.counterfactual(ce)  # counterfactual state extracted from `ce`
+    phat = target_probs(ce)  # NOTE(review): not referenced in the visible part of this hunk — confirm it is used further down
+    plt = Plots.plot(
+        convert2image(MNIST, reshape(x,28,28));  # hard-coded reshape to 28×28 before conversion to an image
+        axis=([], false), 
+        size=size,
+        kwrgs...,  # remaining keyword arguments are forwarded to `Plots.plot`
+    )
 function _plot_eccco_mnist(
     x::Union{AbstractArray, Int}=x_factual, target::Int=target;
@@ -372,6 +383,7 @@ function _plot_eccco_mnist(
     plot_factual::Bool = false,
     test_data::Bool = false,
+    use_energy_delta::Bool = false,
@@ -392,6 +404,7 @@ function _plot_eccco_mnist(
+            use_energy_delta=use_energy_delta,
@@ -458,6 +471,27 @@ display(plt)
 savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
+#### Energy Delta (not in paper)
+plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1,0.1,3.0], use_energy_delta=true)  # `gen_delta` is reused later as the generator for the all-digits grid
+savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
+λ_delta = [0.1,0.1,2.5]  # penalty strengths used for the energy-delta runs in the loop below
+λ = [0.1,0.25,0.25]  # penalty strengths used for the distance-based runs in the loop below
+plts = []
+for i in 0:9  # `i` is passed as the second positional (target) argument of `_plot_eccco_mnist`
+    plt, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ, plot_title="Distance")
+    plt_delta, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ_delta, use_energy_delta=true, plot_title="Energy Delta")
+    plt = Plots.plot(plt, plt_delta; size=(img_height*2,img_height), layout=(1,2))  # both variants in a single row
+    display(plt)
+    push!(plts, plt)
 #### Additional Models (not in paper)
@@ -774,6 +808,29 @@ if _regen_all_digits
+#### Energy Delta (not in paper)
+_regen_all_digits = true  # NOTE(review): hard-coded to true, independent of `_regen`; the guarded block generates 100 counterfactuals — confirm this should always run
+if _regen_all_digits
+    # Build a 10×10 grid of plots, one per (i, j) digit pair with i, j in 0:9.
+    #
+    # Arguments:
+    # - `rng`: forwarded to `plot_mnist` as the `rng` keyword (default 123).
+    # Keywords:
+    # - `verbose`: when true, display every individual plot as it is produced.
+    # - `img_height`: forwarded to `plot_mnist`; the combined figure is sized
+    #   `img_height*10` in both dimensions.
+    # - `kwargs...`: forwarded unchanged to `plot_mnist`.
+    # Returns the combined figure produced by `Plots.plot`.
+    function plot_all_digits(rng=123;verbose=true,img_height=180,kwargs...)
+        plts = []
+        # Same visiting order as the nested loops: `i` is the outer index, `j` the inner one.
+        for i in 0:9, j in 0:9
+            @info "Generating counterfactual for $(i) -> $(j)"
+            plt = plot_mnist(i,j;kwargs...,rng=rng, img_height=img_height)
+            verbose && display(plt)
+            # Append in place; rebuilding the vector by splatting on every
+            # iteration copies all previous plots each time (quadratic work).
+            push!(plts, plt)
+        end
+        plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10), dpi=300)
+        return plt
+    end
+    plt = plot_all_digits(generator=gen_delta)  # `generator` is not a named keyword of `plot_all_digits`; it travels through `kwargs...` to `plot_mnist`
+    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
 ## Benchmark
@@ -854,4 +911,4 @@ plt = draw(
 save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
\ No newline at end of file
+# Linearly Separable Data
+# Hyper:
+_retrain = false
+# Data:
+test_size = 0.2
+n_obs = Int(1000 / (1.0 - test_size))  # presumably sized so that 1000 observations remain after the split — confirm; `Int` throws if the quotient is not integral
+counterfactual_data, test_data = train_test_split(
+    load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.)); 
+    test_size=test_size
+X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
+X = table(permutedims(X))  # transpose, then wrap as a table
+labels = counterfactual_data.output_encoder.labels
+input_dim, n_obs = size(counterfactual_data.X)  # rebinds `n_obs` to the second dimension of `counterfactual_data.X`
+output_dim = length(unique(labels))  # number of distinct labels
+First, let's set up the classifier architecture and training parameters:
+# Model parameters:
+epochs = 100
+batch_size = minimum([Int(round(n_obs/10)), 128])  # a tenth of `n_obs` (rounded), capped at 128
+n_hidden = 16
+activation = Flux.swish  # NOTE(review): defined here, but the builder below passes `Flux.swish` directly — confirm whether `σ=activation` was intended
+builder = MLJFlux.MLP(
+    hidden=(n_hidden, n_hidden, n_hidden),  # three hidden layers of width `n_hidden`
+    σ=Flux.swish
+n_ens = 5                                   # number of models in ensemble
+_loss = Flux.Losses.crossentropy       # loss function
+_finaliser = Flux.softmax                   # finaliser function
+# JEM parameters:
+𝒟x = Normal()
+𝒟y = Categorical(ones(output_dim) ./ output_dim)  # equal probability for each of the `output_dim` classes
+sampler = ConditionalSampler(
+    𝒟x, 𝒟y, 
+    input_size=(input_dim,), 
+    batch_size=50,
+α = [1.0,1.0,1e-1]      # penalty strengths
+# Joint Energy Model:
+model = JointEnergyClassifier(
+    sampler;
+    builder=builder,
+    epochs=epochs,
+    batch_size=batch_size,
+    finaliser=_finaliser,
+    loss=_loss,
+    jem_training_params=(
+        α=α,verbosity=10,
+    ),
+    sampling_steps=30,
+conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
+mach = machine(conf_model, X, labels)
+@info "Begin training model."
+@info "Finished training model."  # NOTE(review): no training call is visible between the two log messages in this view — confirm
+M = ECCCo.ConformalModel(mach.model, mach.fitresult)  # built from the machine's model and fit result
+λ₁ = 0.25
+λ₂ = 0.75
+λ₃ = 0.75
+Λ = [λ₁, λ₂, λ₃]  # passed to both generators as `λ`
+opt = Flux.Optimise.Descent(0.01)
+use_class_loss = false
+# Benchmark generators:
+generator_dict = Dict(
+    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
+    "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true),  # same settings, with `use_energy_delta` switched on
+X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))  # keep `X` if already a Matrix; otherwise convert via `matrix`, transpose and cast to Float32
+factual_label =  levels(labels)[1]
+x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)  # random column among those predicted as `factual_label`, reshaped to input_dim×1
+target =  levels(labels)[2]
+factual = predict_label(M, counterfactual_data, x_factual)[1]  # predicted label of the chosen factual
+ces = Dict{Any,Any}()
+plts = []
+for (name, generator) in generator_dict
+    ce = generate_counterfactual(
+        x_factual, target, counterfactual_data, M, generator;
+        initialization=:identity, 
+        converge_when=:generator_conditions,
+    )
+    plt = Plots.plot(
+        ce, title=name, alpha=0.2, 
+        cbar=false, 
+    )
+    if contains(name, "ECCCo")  # true for both keys of `generator_dict` as defined above
+        _X = distance_from_energy(ce, return_conditionals=true)  # presumably returns the conditional samples rather than the penalty — confirm
+        Plots.scatter!(
+            _X[1,:],_X[2,:], color=:purple, shape=:star5, 
+            ms=10, label="x̂|$target", alpha=0.5
+        )
+    end
+    push!(plts, plt)
+    ces[name] = ce  # keep each counterfactual explanation under its generator name
+plt = Plots.plot(plts..., size=(800,350))
\ No newline at end of file
@@ -25,6 +25,7 @@ function ECCCoGenerator(;
+    use_energy_delta::Bool=false,
@@ -47,7 +48,11 @@ function ECCCoGenerator(;
     # Energy penalty
     function _energy_penalty(ce::AbstractCounterfactualExplanation)
-        return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        if use_energy_delta  # selects which energy penalty is applied; both receive the same keyword arguments
+            return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...)
+        else
+            return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        end
     _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
@@ -1,4 +1,5 @@
 using ChainRules: ignore_derivatives
+using CounterfactualExplanations: get_target_index
 using Distances
 using Flux
 using LinearAlgebra: norm
@@ -38,6 +39,54 @@ function set_size_penalty(
+function energy_delta(  # Penalty: target-class energy (negative logit) of the current counterfactual state minus that of sampled conditionals.
+    ce::AbstractCounterfactualExplanation;
+    n::Int=50, niter=500, from_buffer=true, agg=mean,  # NOTE(review): `agg` is not referenced anywhere in the visible body
+    choose_lowest_energy=true,
+    choose_random=false,
+    nmin::Int=25,
+    return_conditionals=false,
+    kwargs...  # forwarded to `ECCCo.EnergySampler` when the sampler is first created
+    _loss = 0.0
+    nmin = minimum([nmin, n])  # clamp `nmin` to at most `n`
+    @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."  # rejects only the case where both flags are true
+    conditional_samples = []
+    ignore_derivatives() do  # sampler creation and sample selection run inside `ignore_derivatives`
+        _dict = ce.params
+        if !(:energy_sampler ∈ collect(keys(_dict)))  # create the sampler once and cache it in `ce.params`
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
+        end
+        eng_sampler = _dict[:energy_sampler]
+        if choose_lowest_energy
+            nmin = minimum([nmin, size(eng_sampler.buffer)[end]])  # additionally cap by the size of the buffer's last dimension
+            xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
+            push!(conditional_samples, xmin)
+        elseif choose_random
+            push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
+        else
+            push!(conditional_samples, eng_sampler.buffer)  # neither flag set: use the whole buffer
+        end
+    end
+    xgenerated = conditional_samples[1]                         # conditional samples
+    xproposed = CounterfactualExplanations.decode_state(ce)     # current state
+    t = get_target_index(ce.data.y_levels, ce.target)
+    E(x) = -logits(ce.M, x)[t,:]                                # negative logits for target class
+    _loss = E(xproposed) .- E(xgenerated)
+    _loss = reduce((x, y) -> x + y, _loss) / n                  # aggregate over samples; NOTE(review): divides by `n`, which need not equal the number of summed terms (e.g. `nmin` samples in the lowest-energy branch) — confirm
+    if return_conditionals  # checked only after the loss above has already been computed
+        return conditional_samples[1]
+    end
+    return _loss
 function distance_from_energy(
     n::Int=50, niter=500, from_buffer=true, agg=mean,