diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index 872e5a43473de969e478c182fe76d324da88d0dc..0eed13cca664e891d5f09c05ef2f341b4fe67efd 100644
Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ
diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index 42d8d05d0113c303f09bad12ce13865f76cff0cb..d1c8752d6b813e87b8ae1e8c82bf6125a31a30d7 100644
Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 3fa961e0a17f40fbc112bee4c215137f186f49b1..b5b68d58d775cc357835f4944f94c83e8610ac10 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -388,11 +388,6 @@ function _plot_eccco_mnist(
 
     ces = Dict()
     for (mod_name, mod) in model_dict
-        if mod.model.model isa ProbabilisticPipeline
-            x = mod.model.model.f(x)
-            _data = counterfactual_data
-            _data.X = mod.model.model.f(_data.X)
-        end
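+        # All models now consume tabular input, so no per-model preprocessing is needed here.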
         ce = generate_counterfactual(
             x, target, counterfactual_data, mod, eccco_generator; 
             decision_threshold=γ, max_iter=T,
@@ -455,7 +450,11 @@ mutable struct LeNetBuilder
 	channels2::Int
 end
 
-function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
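+# Reshape flattened MNIST digits into 28×28×1 image batches (WHCN layout expected by Flux):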
+preproc(X) = reshape(X, (28, 28, 1, :))
+
+function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
+
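+    # Inputs arrive flattened, so recover the (square) image side length: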
+    _n_in = Int(sqrt(n_in))
 
 	k, c1, c2 = b.filter_size, b.channels1, b.channels2
 
@@ -464,42 +463,35 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
 	# padding to preserve image size on convolution:
 	p = div(k - 1, 2)
 
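+    # Local reshaping closure (shadows the top-level preproc) so the chain itself handles flattened input: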
+    preproc(x) = reshape(x, (_n_in, _n_in, 1, :))
+
 	front = Flux.Chain(
-        Conv((k, k), n_channels => c1, pad=(p, p), relu),
+        Conv((k, k), 1 => c1, pad=(p, p), relu),
         MaxPool((2, 2)),
         Conv((k, k), c1 => c2, pad=(p, p), relu),
         MaxPool((2, 2)),
         Flux.flatten
     )
-	d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
+	d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
     back = Flux.Chain(
         Dense(d, 120, relu),
         Dense(120, 84, relu),
         Dense(84, n_out),
     )
 
-    chain = Flux.Chain(front, back)
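+    # Prepend the reshaping step so the fitted chain consumes flattened features directly: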
+    chain = Flux.Chain(preproc, front, back)
 
 	return chain
 end
 
 # Final model:
-lenet = ImageClassifier(
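+# Tabular classifier instead of ImageClassifier: reshaping happens inside the chain, so no image coercion is needed.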
+lenet = NeuralNetworkClassifier(
     builder=LeNetBuilder(5, 6, 16),
     epochs=epochs,
     batch_size=batch_size,
     finaliser=_finaliser,
     loss=_loss,
 )
-
-# Convert to image:
-function toimg(X)
-    X = matrix(X) |> x -> reshape(x, 28, 28, :)
-    X = coerce(X, GrayImage)
-    return X
-end
-
-lenet = (X -> toimg(X)) |> lenet
 ```
 
 ```{julia}
@@ -508,8 +500,12 @@ add_retrain = true
 # Deep Ensemble:
 mlp_large_ens = EnsembleModel(model=mlp, n=50)
 
+# LeNet-5 Ensemble:
+lenet_ens = EnsembleModel(model=lenet, n=5)
+
 add_models = Dict(
     "LeNet-5" => lenet,
+    "LeNet-5 Ensemble" => lenet_ens,
     "Large Ensemble (n=50)" => mlp_large_ens,
 )
 
@@ -528,6 +524,7 @@ _plt_order = [
     "MLP Ensemble", 
     "Large Ensemble (n=50)", 
     "LeNet-5",
+    "LeNet-5 Ensemble",
     "JEM", 
     "JEM Ensemble",
 ]
diff --git a/src/model.jl b/src/model.jl
index bf3dbe8a49dc235f0209f94f428f555e398493ad..a865d55998043c5831463e3866d164bd092a3d7b 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,3 +1,4 @@
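+# ignore_derivatives runs a closure without tracing it in the AD graph (used in _get_chains below).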
+using ChainRules: ignore_derivatives
 using ConformalPrediction
 using CounterfactualExplanations.Models
 using Flux
@@ -38,13 +39,18 @@ end
 Private function that extracts the chains from a fitted model.
 """
 function _get_chains(fitresult)
-    if fitresult isa MLJEnsembles.WrappedEnsemble
-        chains = map(res -> res[1], fitresult.ensemble)
-    elseif fitresult isa MLJBase.Signature
-        chains = [fitted_params(fitresult)[:image_classifier][1]]      # for piped image classifiers
-    else
-        chains = [fitresult[1]]
+
+    chains = []
+
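+    # Extract chains outside the AD graph: the push! mutation is not Zygote-differentiable.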
+    ignore_derivatives() do
+        if fitresult isa MLJEnsembles.WrappedEnsemble
+            _chains = map(res -> res[1], fitresult.ensemble)
+        else
+            _chains = [fitresult[1]]
+        end
+        push!(chains, _chains...)
     end
+
     return chains
 end
 
@@ -147,7 +153,9 @@ In the binary case logits are fed through the sigmoid function instead of softma
 which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
 """
 function Models.logits(M::ConformalModel, X::AbstractArray)
+
     fitresult = M.fitresult
+
     function predict_logits(fitresult, x)
         ŷ = MLUtils.stack(map(chain -> get_logits(chain,x),_get_chains(fitresult))) |> 
             y -> mean(y, dims=ndims(y)) |>
@@ -162,14 +170,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
         ŷ = ndims(ŷ) > 1 ? ŷ : permutedims([ŷ])
         return ŷ
     end
-    if ndims(X) > 2
-        yhat = map(eachslice(X, dims=ndims(X))) do x
-            predict_logits(fitresult, x)
-        end
-        yhat = MLUtils.stack(yhat)
-    else
-        yhat = predict_logits(fitresult, X)
-    end
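+    # X is always a 2-D feature matrix now that models take tabular input, so a single pass suffices: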
+
+    yhat = predict_logits(fitresult, X)
+
     return yhat
 end
 
@@ -179,6 +182,7 @@ end
 Returns the estimated probabilities for a Conformal Classifier.
 """
 function Models.probs(M::ConformalModel, X::AbstractArray)
+
     if M.likelihood == :classification_binary
         output = σ.(Models.logits(M, X))
     elseif M.likelihood == :classification_multi