diff --git a/experiments/mnist.sh b/experiments/mnist.sh
index 8823002ac8addd0a41b6f90d0b7d20dfb903aa25..87dd05d7fc4608cd967bfc06f0681ea9f13eb438 100644
--- a/experiments/mnist.sh
+++ b/experiments/mnist.sh
@@ -11,4 +11,4 @@
 
 module load 2023r1 openmpi
 
-srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi > experiments/mnist.log
+srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi retrain > experiments/mnist.log
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 63c1c4eb07735a3059148fe04355fe5e9d4f5f48..8c3941443c50f4721e3046c8cdc11385236fd295 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -47,20 +47,9 @@ function prepare_models(exper::Experiment)
         @info "Training models."
         model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage)
     else
-        # Pre-trained models:
-        if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
-            # Load models on root process:
-            @info "Loading pre-trained models."
-            model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
-        else
-            # Dummy model on other processes: 
-            model_dict = nothing
-        end
-        # Broadcast models:
-        if is_multi_processed(exper)
-            model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0)
-            MPI.Barrier(exper.parallelizer.comm)
-        end
+        # Load models on root process:
+        @info "Loading pre-trained models."
+        model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
     end
 
     # Save models: