diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl index 635243444124bd2dcaee329e2cd37f5987103290..14d00675091cdf84e6073021ff32a113eba7fc61 100644 --- a/experiments/models/train_models.jl +++ b/experiments/models/train_models.jl @@ -6,6 +6,7 @@ using CounterfactualExplanations: AbstractParallelizer Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects. """ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractParallelizer}=nothing, train_parallel::Bool=false, kwargs...) + verbose = is_multi_processed(parallelizer) ? false : true if is_multi_processed(parallelizer) && train_parallel # Split models into groups of approximately equal size: model_list = [(key, value) for (key, value) in models] @@ -14,7 +15,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa # Train models: model_dict = Dict() for (mod_name, model) in x - model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=false, kwargs...) + model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...) end MPI.Barrier(parallelizer.comm) output = MPI.gather(output, parallelizer.comm) @@ -28,7 +29,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa model_dict = MPI.bcast(output, parallelizer.comm; root=0) MPI.Barrier(parallelizer.comm) else - model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models) + model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...) for (mod_name, model) in models) end return model_dict end diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index f32ddaa5c992b9e898da18f9a568f82769e75697..fc9217a1b0a9ad4c01c32846f1834370583ff9c1 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -27,6 +27,8 @@ using TidierData Random.seed!(2023) +ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # avoid command prompt and just download data + # Scripts: include("experiment.jl") include("data/data.jl") @@ -124,7 +126,7 @@ if any(contains.(ARGS, "n_individuals=")) n_ind_specified = true n_individuals = ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> x -> replace(x, "n_individuals=" => "") else - n_individuals = 25 + n_individuals = 100 end "Number of individuals to use in benchmarking." diff --git a/experiments/synthetic.sh b/experiments/synthetic.sh index 5cf4c35901363e9f4fdaeb82c03667444f481cdc..1ff7acac3f87ea326c2bfc99aab45375beef9893 100644 --- a/experiments/synthetic.sh +++ b/experiments/synthetic.sh @@ -11,4 +11,4 @@ module load 2023r1 openmpi -srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,cricles output_path=results retrain threaded mpi > experiments/synthetic.log +srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results retrain threaded mpi > experiments/synthetic.log