diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 433e3ba30df6509ee002fed887ba04f5fa7bef51..2582057672d30489bc1b66ae9032965004b129f6 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -25,8 +25,21 @@ epochs = 100
 batch_size = Int(round(n_obs/10))
 n_hidden = 32
 activation = Flux.swish
-builder = MLJFlux.MLP(hidden=(n_hidden,), σ=activation)
-α = [0.33,1.0,1e-1]
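+# Custom builder: inside MLJFlux.@builder, `n_in` and `n_out` are filled in at
+# fit time. The chain ends without a softmax, matching the identity finaliser
+# and logitcrossentropy loss used for the JEM below.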
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden),
+    BatchNorm(n_hidden, activation),
+    Dense(n_hidden, n_out),
+    BatchNorm(n_out)
+)
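+# Alternative MLP builder with three hidden layers, kept for reference: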
+# builder = MLJFlux.MLP(
+#     hidden=(
+#         n_hidden,
+#         n_hidden,
+#         n_hidden,
+#     ), 
+#     σ=activation
+# )
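+# Relative weights of the JEM loss components (presumably classification,
+# generative and regularisation terms, passed via `jem_training_params` below):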
+α = [1.0,1.0,1e-1]
 
 # Simple MLP:
 mlp = NeuralNetworkClassifier(
@@ -36,17 +49,18 @@ mlp = NeuralNetworkClassifier(
 )
 
 # Joint Energy Model:
-𝒟x = Uniform(-1,1)
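+# Prior over inputs, presumably matching the [0,1] pixel scaling of the data: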
+𝒟x = Uniform(0,1)
 𝒟y = Categorical(ones(output_dim) ./ output_dim)
 sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
 jem = JointEnergyClassifier(
     sampler;
     builder=builder,
     batch_size=batch_size,
-    finaliser=Flux.softmax,
-    loss=Flux.Losses.crossentropy,
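+    # The network now outputs raw logits: the identity finaliser skips softmax
+    # and logitcrossentropy applies it internally, which is numerically stabler.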
+    finaliser=identity,
+    loss=Flux.Losses.logitcrossentropy,
     jem_training_params=(α=α,verbosity=10,),
     sampling_steps=20,
+    epochs=epochs,
 )
 
 # Deep Ensemble:
@@ -54,8 +68,8 @@ mlp_ens = EnsembleModel(model=mlp, n=5)
 ```
 
 ```{julia}
-cov = .9
-conf_model = conformal_model(jem; method=:adaptive_inductive, coverage=cov)
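+# Calibrate a simple inductive (split) conformal classifier at 90% coverage: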
+cov = 0.90
+conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
 mach = machine(conf_model, X, labels)
 fit!(mach)
 M = CCE.ConformalModel(mach.model, mach.fitresult)
@@ -63,7 +77,7 @@ M = CCE.ConformalModel(mach.model, mach.fitresult)
 
 ```{julia}
 jem = mach.model.model.jem
-n_iter = 5000
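+# Number of sampling iterations (reduced here, presumably to keep the notebook fast):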
+n_iter = 100
 _w = 1500
 plts = []
 neach = 10
@@ -78,7 +92,8 @@ for i in 1:10
     plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
     plts = [plts..., plt]
 end
-Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
+display(plt)
 ```
 
 ```{julia}
@@ -88,12 +103,14 @@ println("F1 score (test): $(round(f1,digits=3))")
 ```
 
 ```{julia}
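+# Fix the RNG so the factual instance sampled below is reproducible: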
+Random.seed!(1234)
+
 # Set up search:
-factual_label = 9
+factual_label = 2
 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
-target = 4
+target = 0
 factual = predict_label(M, counterfactual_data, x)[1]
-γ = 0.9
+γ = 0.5
 T = 100
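+# (γ is presumably the target probability threshold for convergence and T the
+# maximum number of search steps used by the generators below.)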
 
 # Generate counterfactual using generic generator:
@@ -106,9 +123,9 @@ ce_wachter = generate_counterfactual(
 
 # Generate counterfactual using CCE generator:
 generator = CCEGenerator(
-    λ=[0.0,10.0], 
-    temp=0.01, 
-    # opt=CounterfactualExplanations.Generators.JSMADescent(η=0.5),
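+    # λ presumably weights the two penalty terms of the generator and temp the
+    # temperature of the smooth conformal set-size penalty: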
+    λ=[0.0,1.0],
+    temp=0.5,
+    # opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
 )
 ce_conformal = generate_counterfactual(
     x, target, counterfactual_data, M, generator; 
diff --git a/src/model.jl b/src/model.jl
index 2d3ec05df4434394173ee7f9ec781331e1a2200d..c724867b51e2d03c696b991d1d708766525d1cfb 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -86,7 +86,11 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
             p̂ = [p̂]
         end
         p̂ = reduce(hcat, p̂)
-        ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂))))
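+        # Heuristic: if every entry lies in [0,1], treat p̂ as probabilities and
+        # map them back to logits. Softmax is shift-invariant, so the additive
+        # log(sum(exp.(p))) term is just one convenient constant (note that zero
+        # probabilities map to -Inf). Otherwise p̂ is assumed to hold logits
+        # already, e.g. when the model uses an identity finaliser.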
+        if all(0.0 .<= vec(p̂) .<= 1.0)
+            ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂))))
+        else
+            ŷ = p̂
+        end
         if M.likelihood == :classification_binary
             ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))
         end
diff --git a/www/cce_mnist.png b/www/cce_mnist.png
index 8af78f6df7b2364c9d6883f4b5dea1046c1ac173..3706287ac1d5a3ee59587e8336bb1e687f5a6286 100644
Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ