diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index edccb799635413a5efd692ffbc26e1aa806f8fcb..c0afb2087ea078b7fbf8631438e5f0ff39b87485 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.8.5"
 manifest_format = "2.0"
-project_hash = "512d9080e47cd18c9e3a640716f9947dd8512bcb"
+project_hash = "70fd26ae44f12d0456544e252f519edafee7b553"
 
 [[deps.AbstractFFTs]]
 deps = ["ChainRulesCore", "LinearAlgebra"]
diff --git a/notebooks/Project.toml b/notebooks/Project.toml
index 538c28795473ee6ed358add88b6a4a47de8d10cb..614baa9d9bb2db96651f20a59af9ab5115217362 100644
--- a/notebooks/Project.toml
+++ b/notebooks/Project.toml
@@ -21,4 +21,5 @@ MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Tidier = "f0413319-3358-4bb0-8e7c-0c83523a93bd"
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index c04ca57ef70b11f8d4239ff7fc307fdb25d28318..7a1a41dd51db4beff52d2b9a18c46844b4ed350e 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -7,7 +7,7 @@ eval(setup_notebooks)
 
 ```{julia}
 # Data:
-counterfactual_data = load_mnist(1000)
+counterfactual_data = load_mnist()
 X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
 X = table(permutedims(X))
 labels = counterfactual_data.output_encoder.labels
@@ -16,7 +16,11 @@ input_dim, n_obs = size(counterfactual_data.X)
 
 ```{julia}
 epochs = 100
-clf = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=relu), epochs=epochs)
+clf = NeuralNetworkClassifier(
+    builder=MLJFlux.MLP(hidden=(32,), σ=relu), 
+    epochs=epochs,
+    batch_size=Int(round(n_obs/10))
+)
 conf_model = conformal_model(clf; method=:simple_inductive)
 mach = machine(conf_model, X, labels)
 fit!(mach)
@@ -32,11 +36,17 @@ f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
 println("F1 score (test): $(round(f1,digits=3))")
 ```
 
+```{julia}
+using CounterfactualExplanations.DataPreprocessing: undersample
+dt_reduced = undersample(counterfactual_data, 100)
+# dt_reduced = counterfactual_data
+```
+
 ```{julia}
 # Set up search:
-factual_label = 8
+factual_label = 9
 x = reshape(counterfactual_data.X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
-target = 3
+target = 4
 factual = predict_label(M, counterfactual_data, x)[1]
 γ = 0.9
 T = 100
@@ -44,15 +54,19 @@ T = 100
 # Generate counterfactual using generic generator:
 generator = GenericGenerator()
 ce_wachter = generate_counterfactual(
-    x, target, counterfactual_data, M, generator; 
+    x, target, dt_reduced, M, generator; 
     decision_threshold=γ, max_iter=T,
     initialization=:identity,
 )
 
 # Generate counterfactual using CCE generator:
-generator = CCEGenerator(λ=[0.0,10.0], temp=0.01, opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0))
+generator = CCEGenerator(
+    λ=[0.0,10.0], 
+    temp=0.01, 
+    # opt=CounterfactualExplanations.Generators.JSMADescent(η=5.0),
+)
 ce_conformal = generate_counterfactual(
-    x, target, counterfactual_data, M, generator; 
+    x, target, dt_reduced, M, generator; 
     decision_threshold=γ, max_iter=T,
     initialization=:identity,
     converge_when=:generator_conditions,
diff --git a/www/cce_mnist.png b/www/cce_mnist.png
index 85eed69096b075c1519480191c1693cf0ceaf950..6de563ed4b67f3c267336a85556098625edc0b55 100644
Binary files a/www/cce_mnist.png and b/www/cce_mnist.png differ