From 723f6ccc909eff4a698e79381f445dd2f9a06798 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 12 Sep 2023 07:10:56 +0200
Subject: [PATCH] minor things

---
 experiments/models/train_models.jl | 5 +++--
 experiments/setup_env.jl           | 4 +++-
 experiments/synthetic.sh           | 2 +-
 3 files changed, 7 insertions(+), 4 deletions(-)

diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl
index 63524344..14d00675 100644
--- a/experiments/models/train_models.jl
+++ b/experiments/models/train_models.jl
@@ -6,6 +6,7 @@ using CounterfactualExplanations: AbstractParallelizer
 Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects.
 """
 function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractParallelizer}=nothing, train_parallel::Bool=false, kwargs...)
+    verbose = is_multi_processed(parallelizer) ? false : true
     if is_multi_processed(parallelizer) && train_parallel
         # Split models into groups of approximately equal size:
         model_list = [(key, value) for (key, value) in models]
@@ -14,7 +15,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa
         # Train models:
         model_dict = Dict()
         for (mod_name, model) in x
-            model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=false, kwargs...)
+            model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...)
         end
         MPI.Barrier(parallelizer.comm)
         output = MPI.gather(output, parallelizer.comm)
@@ -28,7 +29,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa
         model_dict = MPI.bcast(output, parallelizer.comm; root=0)
         MPI.Barrier(parallelizer.comm)
     else
-        model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models)
+        model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...) for (mod_name, model) in models)
     end
     return model_dict
 end
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index f32ddaa5..fc9217a1 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -27,6 +27,8 @@ using TidierData
 
 Random.seed!(2023)
 
+ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # avoid command prompt and just download data
+
 # Scripts:
 include("experiment.jl")
 include("data/data.jl")
@@ -124,7 +126,7 @@ if any(contains.(ARGS, "n_individuals="))
     n_ind_specified = true
     n_individuals = ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> x -> replace(x, "n_individuals=" => "")
 else
-    n_individuals = 25
+    n_individuals = 100
 end
 
 "Number of individuals to use in benchmarking."
diff --git a/experiments/synthetic.sh b/experiments/synthetic.sh
index 5cf4c359..1ff7acac 100644
--- a/experiments/synthetic.sh
+++ b/experiments/synthetic.sh
@@ -11,4 +11,4 @@
 
 module load 2023r1 openmpi
 
-srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,cricles output_path=results retrain threaded mpi > experiments/synthetic.log
+srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results retrain threaded mpi > experiments/synthetic.log
--
GitLab
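
Note (not part of the patch): the core change to `train_models` is to silence per-model training output whenever the run is multi-processed, while keeping it verbose in single-process runs. Below is a minimal, self-contained Julia sketch of that pattern; `is_multi_processed`, `_train`, and `train_models_sketch` here are simplified stand-ins for illustration, not the repository's actual implementations.

# Sketch of the verbosity toggle introduced in the patch (illustrative only).
# Assumption: any non-nothing parallelizer stands in for a multi-process run.
is_multi_processed(parallelizer) = !isnothing(parallelizer)

function _train(model, X, y; mod_name, verbose=true, kwargs...)
    verbose && @info "Training $mod_name"
    return model  # actual fitting omitted in this sketch
end

function train_models_sketch(models::Dict, X, y; parallelizer=nothing, kwargs...)
    # Only print progress when running single-process; parallel workers stay quiet.
    verbose = !is_multi_processed(parallelizer)
    return Dict(
        name => _train(model, X, y; mod_name=name, verbose=verbose, kwargs...)
        for (name, model) in models
    )
end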