diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index f8ecaa4d60d29b3c1ed5511e6956598094ea9090..f9c8120ef8f4d74dcd85301b3e23c4143eeda81d 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -1,14 +1,18 @@
 counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
+
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
 run_experiment(
     counterfactual_data, test_data; 
     dataname="GMSC",
-    n_hidden=128,
-    activation = Flux.swish,
-    builder = MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     α=[1.0, 1.0, 1e-1],
     sampling_batch_size=10,
     sampling_steps = 30,
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 21cb7b7924ccabb210dcafd3eb6dde6d9bdb9761..372552e28fb5ab9cb94427177b08eace3e73aefa 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -22,16 +22,19 @@ add_models = Dict(
     :lenet5 => lenet5,
 )
 
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
 # Run:
 run_experiment(
     counterfactual_data, test_data; 
     dataname="MNIST",
-    n_hidden = 128,
-    activation = Flux.swish,
-    builder= MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     𝒟x = Uniform(-1.0, 1.0),
     α = [1.0,1.0,1e-2],
     sampling_batch_size = 10,
@@ -42,6 +45,6 @@ run_experiment(
     nmin = 10,
     use_variants = false,
     use_class_loss = true,
-    add_models = add_models,
+    additional_models = add_models,
     epochs = 10,
 )
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 89975bcd1c6800b606f9e069e80061af5f5ff96b..bd2d5026cfb1dc1173b8091e3638dc89ab7f010b 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -11,6 +11,9 @@ function prepare_models(exp::Experiment)
     if !exp.use_pretrained
         if isnothing(exp.builder)
             builder = default_builder()
+        else
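+            # Use the builder supplied with the experiment when one is given.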
+            builder = exp.builder
         end
         # Default models:
         if isnothing(exp.models)
@@ -25,6 +27,7 @@ function prepare_models(exp::Experiment)
                 use_ensembling=exp.use_ensembling,
                 finaliser=exp.finaliser,
                 loss=exp.loss,
+                epochs=exp.epochs,
             )
         end
         # Additional models:
@@ -36,10 +39,11 @@ function prepare_models(exp::Experiment)
                     batch_size=batch_size(exp),
                     finaliser=exp.finaliser,
                     loss=exp.loss,
+                    epochs=exp.epochs,
                 )
             end
-            add_models = Dict(k => mod(;batch_size=batch_size(exp), ) for (k, mod) in exp.additional_models)
-            models = merge!(models, exp.additional_models)
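+            # Merge the configured additional models into the default model set.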
+            models = merge(models, add_models)
         end
         @info "Training models."
         model_dict = train_models(models, X, labels; cov=exp.coverage)
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 1d6a3be32e3c5ceaa6074ab12b8c004712488bbb..c6657bdd40e2f0e0280b901cf6b3a91d662e93dc 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -12,7 +12,8 @@ using CounterfactualExplanations.Parallelization
 using CSV
 using Dates
 using DataFrames
-using Distributions: Normal, Distribution, Categorical
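+# `Uniform` is used for the MNIST input distribution (𝒟x) in experiments/mnist.jl.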
+using Distributions: Normal, Distribution, Categorical, Uniform
 using ECCCo
 using Flux
 using JointEnergyModels