Skip to content
Snippets Groups Projects
Commit 723f6ccc authored by Pat Alt's avatar Pat Alt
Browse files

minor things

parent 453d5ae7
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
...@@ -6,6 +6,7 @@ using CounterfactualExplanations: AbstractParallelizer ...@@ -6,6 +6,7 @@ using CounterfactualExplanations: AbstractParallelizer
Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects. Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects.
""" """
function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractParallelizer}=nothing, train_parallel::Bool=false, kwargs...) function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractParallelizer}=nothing, train_parallel::Bool=false, kwargs...)
verbose = is_multi_processed(parallelizer) ? false : true
if is_multi_processed(parallelizer) && train_parallel if is_multi_processed(parallelizer) && train_parallel
# Split models into groups of approximately equal size: # Split models into groups of approximately equal size:
model_list = [(key, value) for (key, value) in models] model_list = [(key, value) for (key, value) in models]
...@@ -14,7 +15,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa ...@@ -14,7 +15,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa
# Train models: # Train models:
model_dict = Dict() model_dict = Dict()
for (mod_name, model) in x for (mod_name, model) in x
model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=false, kwargs...) model_dict[mod_name] = _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...)
end end
MPI.Barrier(parallelizer.comm) MPI.Barrier(parallelizer.comm)
output = MPI.gather(output, parallelizer.comm) output = MPI.gather(output, parallelizer.comm)
...@@ -28,7 +29,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa ...@@ -28,7 +29,7 @@ function train_models(models::Dict, X, y; parallelizer::Union{Nothing,AbstractPa
model_dict = MPI.bcast(output, parallelizer.comm; root=0) model_dict = MPI.bcast(output, parallelizer.comm; root=0)
MPI.Barrier(parallelizer.comm) MPI.Barrier(parallelizer.comm)
else else
model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models) model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, verbose=verbose, kwargs...) for (mod_name, model) in models)
end end
return model_dict return model_dict
end end
......
...@@ -27,6 +27,8 @@ using TidierData ...@@ -27,6 +27,8 @@ using TidierData
Random.seed!(2023) Random.seed!(2023)
ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # avoid command prompt and just download data
# Scripts: # Scripts:
include("experiment.jl") include("experiment.jl")
include("data/data.jl") include("data/data.jl")
...@@ -124,7 +126,7 @@ if any(contains.(ARGS, "n_individuals=")) ...@@ -124,7 +126,7 @@ if any(contains.(ARGS, "n_individuals="))
n_ind_specified = true n_ind_specified = true
n_individuals = ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> x -> replace(x, "n_individuals=" => "") n_individuals = ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> x -> replace(x, "n_individuals=" => "")
else else
n_individuals = 25 n_individuals = 100
end end
"Number of individuals to use in benchmarking." "Number of individuals to use in benchmarking."
......
...@@ -11,4 +11,4 @@ ...@@ -11,4 +11,4 @@
module load 2023r1 openmpi module load 2023r1 openmpi
srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,cricles output_path=results retrain threaded mpi > experiments/synthetic.log srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results retrain threaded mpi > experiments/synthetic.log
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment