From 41299aae1dcca42d70a0eaebb621a2b3726f33c9 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Wed, 13 Sep 2023 11:26:03 +0200 Subject: [PATCH] bloody hell --- experiments/Manifest.toml | 6 +++--- experiments/linearly_separable.jl | 3 ++- experiments/models/models.jl | 6 ++++++ 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index f7c37dd6..9027f359 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -2232,10 +2232,10 @@ version = "1.3.0" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" [[deps.StatsModels]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsBase", "StatsFuns", "Tables"] -git-tree-sha1 = "8cc7a5385ecaa420f0b3426f9b0135d0df0638ed" +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsAPI", "StatsBase", "StatsFuns", "Tables"] +git-tree-sha1 = "5cf6c4583533ee38639f73b880f35fc85f2941e0" uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d" -version = "0.7.2" +version = "0.7.3" [[deps.Strided]] deps = ["LinearAlgebra", "TupleTools"] diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index faa3bf27..0e54a672 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -6,6 +6,7 @@ counterfactual_data, test_data = train_test_split( run_experiment( counterfactual_data, test_data; dataname="Linearly Separable", - nsamples=1, + nsamples=100, nmin=1, + niter_eccco=30 ) \ No newline at end of file diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 900e95b9..63c1c4eb 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -47,10 +47,16 @@ function prepare_models(exper::Experiment) @info "Training models." model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage) else + # Pre-trained models: if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + # Load models on root process: @info "Loading pre-trained models." model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls")) + else + # Dummy model on other processes: + model_dict = nothing end + # Broadcast models: if is_multi_processed(exper) model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0) MPI.Barrier(exper.parallelizer.comm) -- GitLab