diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 8c3941443c50f4721e3046c8cdc11385236fd295..9aa6306929898ddd546e082ab636dd8ce4c0466b 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -47,9 +47,19 @@ function prepare_models(exper::Experiment) @info "Training models." model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage) else - # Load models on root process: - @info "Loading pre-trained models." - model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls")) + # Pre-trained models: + if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + # Load models on root process: + @info "Loading pre-trained models." + model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls")) + else + # Dummy model on other processes: + model_dict = nothing + end + # Broadcast models: + if is_multi_processed(exper) + model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0) + end end # Save models: