diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index 400facb7be5d41186be3476bfe291d7420d08ee1..567a5edf16ff6947804306d6569a8706f2fd81f7 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -17,9 +17,9 @@ version = "0.4.4"
 
 [[deps.Accessors]]
 deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"]
-git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c"
+git-tree-sha1 = "c7dddee3f32ceac12abd9a21cd0c4cb489f230d2"
 uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-version = "0.1.28"
+version = "0.1.29"
 
 [[deps.Adapt]]
 deps = ["LinearAlgebra", "Requires"]
@@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 8a7f179f1dc24a1e55975659fc4c2a5b45d6ffa6..0939ebd0ade4958b1a52644ae795d66309a36c21 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -94,7 +94,7 @@ mlp_ens = EnsembleModel(model=mlp, n=50)
 
 ```{julia}
 cov = .95
-conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov)
+conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
 mach = machine(conf_model, X, labels)
 fit!(mach)
 M = ECCCo.ConformalModel(mach.model, mach.fitresult)
@@ -130,6 +130,15 @@ f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
 println("F1 score (test): $(round(f1,digits=3))")
 ```
 
+```{julia}
+𝒟x = Uniform(-1,1)
+𝒟y = Categorical(ones(K) ./ K)
+sampler = UnconditionalSampler(𝒟x; input_size=(D,))
+conditional_sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=(D,))
+opt = ImproperSGLD()
+n_iter = 100
+```
+
 ```{julia}
 # Random.seed!(1234)
 
@@ -141,15 +150,17 @@ factual = predict_label(M, counterfactual_data, x)[1]
 γ = 0.5
 T = 100
 
+η=1.0
+
 # Generate counterfactual using generic generator:
-generator = GenericGenerator(opt=Flux.Optimise.Adam(),)
+generator = GenericGenerator()
 ce_wachter = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
     decision_threshold=γ, max_iter=T,
     initialization=:identity,
 )
 
-generator = GreedyGenerator(η=1.0)
+generator = GreedyGenerator(η=η)
 ce_jsma = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
     decision_threshold=γ, max_iter=T,
@@ -157,14 +168,13 @@ ce_jsma = generate_counterfactual(
 )
 
 # ECCCo:
-λ=[0.0,1.0]
-temp=0.5
+λ=[0.0,1.0,10.0]
+temp=0.01
 
-# Generate counterfactual using CCE generator:
-generator = CCEGenerator(
+# Generate counterfactual using ECCCo generator:
+generator = ECCCoGenerator(
     λ=λ, 
     temp=temp, 
-    opt=Flux.Optimise.Adam(),
 )
 ce_conformal = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
@@ -173,11 +183,11 @@ ce_conformal = generate_counterfactual(
     converge_when=:generator_conditions,
 )
 
-# Generate counterfactual using CCE generator:
-generator = CCEGenerator(
+# Generate counterfactual using ECCCo generator:
+generator = ECCCoGenerator(
     λ=λ, 
     temp=temp, 
-    opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
+    opt=CounterfactualExplanations.Generators.JSMADescent(η=η),
 )
 ce_conformal_jsma = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
@@ -196,7 +206,7 @@ p1 = Plots.plot(
 plts = [p1]
 
 ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma]
-_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"]
+_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"]
 for x in zip(ces, _names)
     ce, _name = (x[1],x[2])
     x = CounterfactualExplanations.counterfactual(ce)
@@ -212,7 +222,7 @@ for x in zip(ces, _names)
 end
 plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
 display(plt)
-savefig(plt, joinpath(www_path, "cce_mnist.png"))
+savefig(plt, joinpath(www_path, "eccco_mnist.png"))
 ```
 
 ```{julia}
@@ -226,8 +236,6 @@ factual = predict_label(M, counterfactual_data, x)[1]
 γ = 0.5
 T = 100
 
-η=0.1
-
 # Generate counterfactual using generic generator:
 generator = GenericGenerator(opt=Flux.Optimise.Adam(),)
 ce_wachter = generate_counterfactual(
@@ -236,7 +244,7 @@ ce_wachter = generate_counterfactual(
     initialization=:identity,
 )
 
-generator = GreedyGenerator(η=η)
+generator = GreedyGenerator(η=1.0)
 ce_jsma = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
     decision_threshold=γ, max_iter=T,
@@ -244,11 +252,11 @@ ce_jsma = generate_counterfactual(
 )
 
 # ECCCo:
-λ=[0.0,0.0,10.0]
+λ=[0.0,1.0]
 temp=0.5
 
-# Generate counterfactual using ECCCo generator:
-generator = ECCCoGenerator(
+# Generate counterfactual using CCE generator:
+generator = CCEGenerator(
     λ=λ, 
     temp=temp, 
     opt=Flux.Optimise.Adam(),
@@ -260,11 +268,11 @@ ce_conformal = generate_counterfactual(
     converge_when=:generator_conditions,
 )
 
-# Generate counterfactual using ECCCo generator:
-generator = ECCCoGenerator(
+# Generate counterfactual using CCE generator:
+generator = CCEGenerator(
     λ=λ, 
     temp=temp, 
-    opt=CounterfactualExplanations.Generators.JSMADescent(η=η),
+    opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
 )
 ce_conformal_jsma = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
@@ -283,7 +291,7 @@ p1 = Plots.plot(
 plts = [p1]
 
 ces = [ce_wachter, ce_conformal, ce_jsma, ce_conformal_jsma]
-_names = ["Wachter", "ECCCo", "JSMA", "ECCCo-JSMA"]
+_names = ["Wachter", "CCE", "JSMA", "CCE-JSMA"]
 for x in zip(ces, _names)
     ce, _name = (x[1],x[2])
     x = CounterfactualExplanations.counterfactual(ce)
@@ -299,7 +307,7 @@ for x in zip(ces, _names)
 end
 plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
 display(plt)
-savefig(plt, joinpath(www_path, "eccco_mnist.png"))
+savefig(plt, joinpath(www_path, "cce_mnist.png"))
 ```
 
 ```{julia}
diff --git a/paper/paper.pdf b/paper/paper.pdf
index 222cec3cdb4a1bfb1966a554996d478b458fc381..eddb381562b5e51c254fb4629f64e5cc0707647a 100644
Binary files a/paper/paper.pdf and b/paper/paper.pdf differ
diff --git a/src/penalties.jl b/src/penalties.jl
index 95c96b4109cba3bf5f40d4c8f7ae03e44bfefceb..e9af36f125a7ddc0df9005d7b35bd70e9eb9d4c9 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -36,13 +36,13 @@ end
 
 function distance_from_energy(
     ce::AbstractCounterfactualExplanation;
-    n::Int=1, niter=200, from_buffer=true, agg=mean, kwargs...
+    n::Int=10, niter=200, from_buffer=true, agg=mean, kwargs...
 )
     conditional_samples = []
     ignore_derivatives() do
         _dict = ce.params
         if !(:energy_sampler ∈ collect(keys(_dict)))
-            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=1000, kwargs...)
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
         end
         sampler = _dict[:energy_sampler]
         push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
diff --git a/src/sampling.jl b/src/sampling.jl
index 50a170283a65f6e0ad64faae58084913898f0a95..023c3d87fd7a138fc21431f4fd603aa19ace771e 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -57,10 +57,8 @@ function EnergySampler(
     # Initiate:
     energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing)
 
-    # Generate samples:
-    chain = model.model.model.jem.chain
-    rule = model.model.model.jem.sampling_rule
-    energy_sampler.sampler(chain, rule; niter=niter, n_samples=nsamples, y=yidx)
+    # Generate conditional samples:
+    generate_samples!(energy_sampler, nsamples, yidx; niter=niter)
 
     return energy_sampler
 end
@@ -86,19 +84,43 @@ function EnergySampler(
     return EnergySampler(model, data, y; kwrgs...)
 end
 
+"""
+    generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
+
+Generates `n` samples from `EnergySampler` for conditioning value `y`.
+"""
+function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
+
+    # Generate samples:
+    chain = e.model.fitresult[1]
+    rule = e.opt
+    xsamples = e.sampler(chain, rule; niter=niter, n_samples=n, y=y)
+
+    return xsamples
+end
+
+"""
+    generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
+
+Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`.
+"""
+function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
+    e.buffer = generate_samples(e, n, y; niter=niter)
+    e.yidx = y
+end
+
 """
     Base.rand(sampler::EnergySampler, n::Int=100; retrain=false)
 
 Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`.
 """
 function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100)
-    ntotal = size(sampler.sampler.buffer)[end]
+    ntotal = size(sampler.buffer, 2)
     idx = rand(1:ntotal, n)
     if from_buffer
-        X = sampler.sampler.buffer[:, idx]
+        X = sampler.buffer[:, idx]
     else
-        chain = sampler.model.fitresult[1]
-        X = sampler.sampler(chain, sampler.opt; niter=niter, n_samples=n, y=sampler.yidx)
+        X = generate_samples(sampler, n, sampler.yidx; niter=niter)
     end
     return X
 end
diff --git a/www/eccco_mnist.png b/www/eccco_mnist.png
index 81cecbe4d0f947bf55ef08cd7f2db35f318337c2..8c8f3a1869c09a98aafb4cfdb94a5b4f09bf041c 100644
Binary files a/www/eccco_mnist.png and b/www/eccco_mnist.png differ