diff --git a/Manifest.toml b/Manifest.toml
index d16a0962307e2701b4cb33190dcace9d31955db4..4531e744c4232042cd1891c52c5235f2fdd795ca 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -304,7 +304,7 @@ uuid = "d38c429a-6771-53c6-b99e-75d170b6e991"
 version = "0.6.2"
 
 [[deps.CounterfactualExplanations]]
-deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
+deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
 path = "../CounterfactualExplanations.jl"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 version = "0.1.9"
diff --git a/docs/Manifest.toml b/docs/Manifest.toml
new file mode 100644
index 0000000000000000000000000000000000000000..bc62edc143f3749077aee8e4207d6539844d299a
--- /dev/null
+++ b/docs/Manifest.toml
@@ -0,0 +1,7 @@
+# This file is machine-generated - editing it directly is not advised
+
+julia_version = "1.8.5"
+manifest_format = "2.0"
+project_hash = "da39a3ee5e6b4b0d3255bfef95601890afd80709"
+
+[deps]
diff --git a/docs/Project.toml b/docs/Project.toml
new file mode 100644
index 0000000000000000000000000000000000000000..81648c0b16f0c3c6b1d2f26cd5994aa97a5b3d22
--- /dev/null
+++ b/docs/Project.toml
@@ -0,0 +1 @@
+[deps]
diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index 6dd7ba4c2af64778888eecd9f46414a07292e102..a1c318aaa24efd6ebab22371e0af1b3394f2a008 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -383,7 +383,7 @@ uuid = "150eb455-5306-5404-9cee-2592286d6298"
 version = "0.6.2"
 
 [[deps.CounterfactualExplanations]]
-deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
+deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "SnoopPrecompile", "Statistics", "StatsBase", "Tables", "UMAP"]
 path = "../../CounterfactualExplanations.jl"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 version = "0.1.9"
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 0716c8bcfe509a95540812fd7f3907628b7c6783..07a6fec5496bb6ca4d3590372b22901a3c834dbf 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -22,12 +22,14 @@ First, let's create a couple of image classifier architectures:
 ```{julia}
 # Model parameters:
 epochs = 100
-batch_size = minimum([Int(round(n_obs/10)), 100])
-n_hidden = 32
+batch_size = minimum([Int(round(n_obs/10)), 128])
+n_hidden = 200
 activation = Flux.relu
 builder = MLJFlux.@builder Flux.Chain(
     Dense(n_in, n_hidden),
     BatchNorm(n_hidden, activation),
+    Dense(n_hidden, n_hidden),
+    BatchNorm(n_hidden, activation),
     Dense(n_hidden, n_out),
 )
 # builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation)
@@ -51,7 +53,11 @@ mlp = NeuralNetworkClassifier(
 # Joint Energy Model:
 𝒟x = Uniform(0,1)
 𝒟y = Categorical(ones(output_dim) ./ output_dim)
-sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
+sampler = ConditionalSampler(
+    𝒟x, 𝒟y, 
+    input_size=(input_dim,), 
+    batch_size=10
+)
 jem = JointEnergyClassifier(
     sampler;
     builder=builder,
@@ -73,31 +79,33 @@ mlp_ens = EnsembleModel(model=mlp, n=5)
 
 ```{julia}
 cov = .90
-conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
+conf_model = conformal_model(mlp; method=:adaptive_inductive, coverage=cov)
 mach = machine(conf_model, X, labels)
 fit!(mach)
 M = CCE.ConformalModel(mach.model, mach.fitresult)
 ```
 
 ```{julia}
-jem = mach.model.model.jem
-n_iter = 100
-_w = 1500
-plts = []
-neach = 10
-for i in 1:10
-    x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
-    plts_i = []
-    for j in 1:size(x, 2)
-        xj = x[:,j]
-        xj = reshape(xj, (n_digits, n_digits))
-        plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+if mach.model.model isa JointEnergyModels.JointEnergyClassifier
+    jem = mach.model.model.jem
+    n_iter = 100
+    _w = 1500
+    plts = []
+    neach = 10
+    for i in 1:10
+        x = jem.sampler(jem.chain, jem.sampling_rule; niter=n_iter, n_samples=neach, y=i)
+        plts_i = []
+        for j in 1:size(x, 2)
+            xj = x[:,j]
+            xj = reshape(xj, (n_digits, n_digits))
+            plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+        end
+        plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
+        plts = [plts..., plt]
     end
-    plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
-    plts = [plts..., plt]
+    plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+    display(plt)
 end
-plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
-display(plt)
 ```
 
 ```{julia}
@@ -110,12 +118,12 @@ println("F1 score (test): $(round(f1,digits=3))")
 Random.seed!(1234)
 
 # Set up search:
-factual_label = 9
+factual_label = 4
 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
-target = 4
+target = 9
 factual = predict_label(M, counterfactual_data, x)[1]
 γ = 0.5
-T = 1
+T = 100
 
 # Generate counterfactual using generic generator:
 generator = GenericGenerator()
@@ -125,11 +133,14 @@ ce_wachter = generate_counterfactual(
     initialization=:identity,
 )
 
+# CCE:
+λ=[0.0,1.0]
+temp=0.5
+
 # Generate counterfactual using CCE generator:
 generator = CCEGenerator(
-    λ=[0.0,1.0], 
-    temp=0.5, 
-    # opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
+    λ=λ, 
+    temp=temp, 
 )
 ce_conformal = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
@@ -138,6 +149,19 @@ ce_conformal = generate_counterfactual(
     converge_when=:generator_conditions,
 )
 
+# Generate counterfactual using CCE generator:
+generator = CCEGenerator(
+    λ=λ, 
+    temp=temp, 
+    opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
+)
+ce_conformal_jsma = generate_counterfactual(
+    x, target, counterfactual_data, M, generator; 
+    decision_threshold=γ, max_iter=T,
+    initialization=:identity,
+    converge_when=:generator_conditions,
+)
+
 # Plot:
 p1 = Plots.plot(
     convert2image(MNIST, reshape(x,28,28)),
@@ -147,12 +171,13 @@ p1 = Plots.plot(
 )
 plts = [p1]
 
-ces = zip([ce_wachter,ce_conformal])
+ces = zip([ce_wachter,ce_conformal,ce_conformal_jsma])
+_names = ["Wachter", "CCE", "CCE-JSMA"]
 counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
 phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
-for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
-    ce, _phat = (x[1],x[2])
-    _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))"
+for x in zip(eachslice(counterfactuals; dims=3), _names, eachslice(phat; dims=3))
+    ce, _name, _phat = (x[1],x[2],x[3])
+    _title = "$_name (p̂=$(round(_phat[1]; digits=3)))"
     plt = Plots.plot(
         convert2image(MNIST, reshape(ce,28,28)),
         axis=nothing, 
diff --git a/src/model.jl b/src/model.jl
index c724867b51e2d03c696b991d1d708766525d1cfb..cc7cde3c24186017c91040107d06e1623f187089 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -52,6 +52,7 @@ function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symb
     end
 
     # Construct model:
+    testmode!(fitresult[1])
     M = ConformalModel(model, fitresult, likelihood)
     return M
 end
diff --git a/src/penalties.jl b/src/penalties.jl
index 09d2c2a28def9c2f4194647f9baf875d7af7c2d1..30ae653c91f21d06eb794ffa3369e9c5cdb6f9a7 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -3,21 +3,21 @@ using LinearAlgebra: norm
 using Statistics: mean
 
 """
-    set_size_penalty(counterfactual_explanation::AbstractCounterfactualExplanation)
+    set_size_penalty(ce::AbstractCounterfactualExplanation)
 
 Penalty for smooth conformal set size.
 """
 function set_size_penalty(
-    counterfactual_explanation::AbstractCounterfactualExplanation; 
+    ce::AbstractCounterfactualExplanation; 
     κ::Real=0.0, temp::Real=0.05, agg=mean
 )
 
-    conf_model = counterfactual_explanation.M.model
-    fitresult = counterfactual_explanation.M.fitresult
-    X = CounterfactualExplanations.decode_state(counterfactual_explanation)
-    loss = map(eachslice(X, dims=3)) do x
-        x = Matrix(x)
-        if target_probs(counterfactual_explanation, x)[1] >= 0.5
+    conf_model = ce.M.model
+    fitresult = ce.M.fitresult
+    X = CounterfactualExplanations.decode_state(ce)
+    loss = map(eachslice(X, dims=ndims(X))) do x
+        x = ndims(x) == 1 ? x[:,:] : x
+        if target_probs(ce, x)[1] >= 0.5
             l = ConformalPrediction.smooth_size_loss(
                 conf_model, fitresult, x';
                 κ=κ,
@@ -35,19 +35,19 @@ function set_size_penalty(
 end
 
 function distance_from_energy(
-    counterfactual_explanation::AbstractCounterfactualExplanation;
+    ce::AbstractCounterfactualExplanation;
     n::Int=10000, from_buffer=true, agg=mean, kwargs...
 )
     conditional_samples = []
     ignore_derivatives() do
-        _dict = counterfactual_explanation.params
+        _dict = ce.params
         if !(:energy_sampler ∈ collect(keys(_dict)))
-            _dict[:energy_sampler] = CCE.EnergySampler(counterfactual_explanation; kwargs...)
+            _dict[:energy_sampler] = CCE.EnergySampler(ce; kwargs...)
         end
         sampler = _dict[:energy_sampler]
         push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
     end
-    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
+    x′ = CounterfactualExplanations.counterfactual(ce)
     loss = map(eachslice(x′, dims=3)) do x
         x = Matrix(x)
         Δ = map(eachcol(conditional_samples[1])) do xsample
@@ -62,13 +62,13 @@ function distance_from_energy(
 end
 
 function distance_from_targets(
-    counterfactual_explanation::AbstractCounterfactualExplanation;
+    ce::AbstractCounterfactualExplanation;
     n::Int=10000, agg=mean
 )
-    target_idx = counterfactual_explanation.data.output_encoder.labels .== counterfactual_explanation.target
-    target_samples = counterfactual_explanation.data.X[:,target_idx] |>
+    target_idx = ce.data.output_encoder.labels .== ce.target
+    target_samples = ce.data.X[:,target_idx] |>
         X -> X[:,rand(1:end,n)]
-    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
+    x′ = CounterfactualExplanations.counterfactual(ce)
     loss = map(eachslice(x′, dims=3)) do x
         x = Matrix(x)
         Δ = map(eachcol(target_samples)) do xsample
diff --git a/www/cce_mnist.png b/www/cce_mnist.png
index ea956a80689729139b58966c11514539991a792a..ff97f8657cd5af3cd5f753e84cb38286614bf6ca 100644
Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ