diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index f7c37dd6eaff0a738e5a24d25893c715b0b35bb0..9027f359713afa50cb343df2ac2a1746615238a4 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -2232,10 +2232,10 @@ version = "1.3.0" InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" [[deps.StatsModels]] -deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsBase", "StatsFuns", "Tables"] -git-tree-sha1 = "8cc7a5385ecaa420f0b3426f9b0135d0df0638ed" +deps = ["DataAPI", "DataStructures", "LinearAlgebra", "Printf", "REPL", "ShiftedArrays", "SparseArrays", "StatsAPI", "StatsBase", "StatsFuns", "Tables"] +git-tree-sha1 = "5cf6c4583533ee38639f73b880f35fc85f2941e0" uuid = "3eaba693-59b7-5ba5-a881-562e759f1c8d" -version = "0.7.2" +version = "0.7.3" [[deps.Strided]] deps = ["LinearAlgebra", "TupleTools"] diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index faa3bf278e56aa1d9ef5c3ced6a2dc43b2f361b0..0e54a672a2d3ff25e084f738b777dfd4546c9b22 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -6,6 +6,7 @@ counterfactual_data, test_data = train_test_split( run_experiment( counterfactual_data, test_data; dataname="Linearly Separable", - nsamples=1, + nsamples=100, nmin=1, + niter_eccco=30 ) \ No newline at end of file diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 900e95b9637bbe095e4b1997a93ee70015c50ae0..63c1c4eb07735a3059148fe04355fe5e9d4f5f48 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -47,10 +47,16 @@ function prepare_models(exper::Experiment) @info "Training models." model_dict = train_models(models, X, labels; parallelizer=exper.parallelizer, train_parallel=exper.train_parallel, cov=exper.coverage) else + # Pre-trained models: if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + # Load models on root process: @info "Loading pre-trained models." model_dict = Serialization.deserialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls")) + else + # Dummy model on other processes: + model_dict = nothing end + # Broadcast models: if is_multi_processed(exper) model_dict = MPI.bcast(model_dict, exper.parallelizer.comm; root=0) MPI.Barrier(exper.parallelizer.comm)