From f19dceceb432d4087778cb7a0785f35221e5a665 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Wed, 13 Sep 2023 21:43:37 +0200
Subject: [PATCH] try loading from pre-trained models again

---
 experiments/models/models.jl | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)
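
Note: the hunk below deserializes the pre-trained models only on the MPI root
rank and then broadcasts the result to every other rank, so non-root ranks
never touch the file on disk. A minimal standalone sketch of that
load-on-root-then-broadcast pattern with MPI.jl follows; the file name and the
absence of the experiment wrapper are assumptions for illustration, not part
of this patch:

    using MPI, Serialization

    MPI.Init()
    comm = MPI.COMM_WORLD

    # Deserialize on the root rank only; all other ranks hold a placeholder.
    model_dict = if MPI.Comm_rank(comm) == 0
        Serialization.deserialize("models.jls")  # hypothetical path
    else
        nothing
    end

    # Ship the deserialized object from the root rank to every rank in the
    # communicator; after this call all ranks hold an identical copy.
    model_dict = MPI.bcast(model_dict, comm; root=0)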

diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 8c394144..9aa63069 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -47,9 +47,19 @@ function prepare_models(exper::Experiment)
         @info "Training models."
         model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage)
     else
-        # Load models on root process:
-        @info "Loading pre-trained models."
-        model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
+        # Pre-trained models:
+        if !is_multi_processed(exper) || MPI.Comm_rank(exper.parallelizer.comm) == 0
+            # Load models on root process:
+            @info "Loading pre-trained models."
+            model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))
+        else
+            # Placeholder on non-root processes:
+            model_dict = nothing
+        end
+        # Broadcast models from the root to all processes:
+        if is_multi_processed(exper)
+            model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0)
+        end
     end
 
     # Save models:
-- 
GitLab