diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index a55474be7a30c77ab1b046ad0c2794523c514d94..872e5a43473de969e478c182fe76d324da88d0dc 100644
Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ
diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index 8ed61727f0b7c3e5775110beeca6df7f9c9d10ca..42d8d05d0113c303f09bad12ce13865f76cff0cb 100644
Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index feedfb983c656e54325ec8ae822e7ddb7bd7cc7b..3fa961e0a17f40fbc112bee4c215137f186f49b1 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -388,6 +388,11 @@ function _plot_eccco_mnist(
 
     ces = Dict()
     for (mod_name, mod) in model_dict
+        if mod.model.model isa ProbabilisticPipeline
+            # Apply the pipeline's pre-processing to the factual and (in place) to the counterfactual data:
+            x = mod.model.model.f(x)
+            counterfactual_data.X = mod.model.model.f(counterfactual_data.X)
+        end
         ce = generate_counterfactual(
             x, target, counterfactual_data, mod, eccco_generator; 
             decision_threshold=γ, max_iter=T,
@@ -460,18 +465,22 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
 	p = div(k - 1, 2)
 
 	front = Flux.Chain(
-			   Conv((k, k), n_channels => c1, pad=(p, p), relu),
-			   MaxPool((2, 2)),
-			   Conv((k, k), c1 => c2, pad=(p, p), relu),
-			   MaxPool((2, 2)),
-			   Flux.flatten)
+        Conv((k, k), n_channels => c1, pad=(p, p), relu),
+        MaxPool((2, 2)),
+        Conv((k, k), c1 => c2, pad=(p, p), relu),
+        MaxPool((2, 2)),
+        Flux.flatten
+    )
 	d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
     back = Flux.Chain(
-        Dense(256, 120, relu),
+        Dense(d, 120, relu),
         Dense(120, 84, relu),
         Dense(84, n_out),
     )
-	return Flux.Chain(front, back)
+
+    chain = Flux.Chain(front, back)
+
+    return chain
 end
 
 # Final model:
@@ -500,8 +509,8 @@ add_retrain = true
 mlp_large_ens = EnsembleModel(model=mlp, n=50)
 
 add_models = Dict(
-    "Large Ensemble" => mlp_large_ens,
     "LeNet-5" => lenet,
+    "Large Ensemble (n=50)" => mlp_large_ens,
 )
 
 if add_retrain
@@ -517,7 +526,8 @@ end
 _plt_order = [
     "MLP", 
     "MLP Ensemble", 
-    "Large Ensemble", 
+    "Large Ensemble (n=50)", 
+    "LeNet-5",
     "JEM", 
     "JEM Ensemble",
 ]
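
Note on the LeNet builder change above: the first Dense layer now takes its input width from d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first instead of the hardcoded 256, so the classifier head adapts to whatever the convolutional front end actually emits. A standalone sketch of that computation follows; the values of k, c1 and c2 are assumptions for illustration, since the LeNetBuilder's fields are outside this hunk.

using Flux

k, c1, c2 = 5, 6, 16                  # assumed builder settings, not taken from the patch
p = div(k - 1, 2)                     # "same" padding, as in the builder
front = Flux.Chain(
    Conv((k, k), 1 => c1, pad=(p, p), relu),
    MaxPool((2, 2)),
    Conv((k, k), c1 => c2, pad=(p, p), relu),
    MaxPool((2, 2)),
    Flux.flatten,
)
# For 28×28 single-channel MNIST inputs, the two 2×2 poolings leave 7×7 feature maps,
# so the flattened width is c2 * 7 * 7 = 784 rather than 256; computing `d` keeps the
# first Dense layer consistent with the front end.
d = Flux.outputsize(front, (28, 28, 1, 1)) |> first   # 784 under these assumptions
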
diff --git a/src/model.jl b/src/model.jl
index 57fcf48a2a1e38ce81712c6a7eae3cdf5c223491..bf3dbe8a49dc235f0209f94f428f555e398493ad 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -40,6 +40,8 @@ Private function that extracts the chains from a fitted model.
 function _get_chains(fitresult)
     if fitresult isa MLJEnsembles.WrappedEnsemble
         chains = map(res -> res[1], fitresult.ensemble)
+    elseif fitresult isa MLJBase.Signature
+        chains = [fitted_params(fitresult)[:image_classifier][1]]  # learning-network composites, e.g. pipelined image classifiers
     else
         chains = [fitresult[1]]
     end