diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml index 244c5a2f918b6a0a3c502f6d1bfd4c480d95219b..a82fd00b9068dee1a0d391aeeb7675f54015117d 100644 --- a/experiments/Manifest.toml +++ b/experiments/Manifest.toml @@ -444,9 +444,11 @@ version = "0.6.2" [[deps.CounterfactualExplanations]] deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"] -git-tree-sha1 = "04acac995c56d1609277848e3bdd035e89b712c2" +git-tree-sha1 = "108bf1c3a0cefdcc53bda5482ea1aa6d8ee86382" +repo-rev = "main" +repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.21" +version = "0.1.23" [deps.CounterfactualExplanations.extensions] MPIExt = "MPI" diff --git a/experiments/experiment.jl b/experiments/experiment.jl index baf73b49845f7109dc29195860254bca3791fd37..74755b7b9927e16ea2903fa38704efb9a5a4cb2f 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -54,8 +54,8 @@ end Train the models specified by `exper` and store them in `outcome`. """ -function train_models!(outcome::ExperimentOutcome, exper::Experiment; save_meta::Bool=false) - model_dict = prepare_models(exper) +function train_models!(outcome::ExperimentOutcome, exper::Experiment; save_models::Bool=true, save_meta::Bool=false) + model_dict = prepare_models(exper; save_models=save_models) outcome.model_dict = model_dict if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) meta_model_performance(outcome; save_output=save_meta) @@ -97,10 +97,10 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models:: # Model training: if only_models - train_models!(outcome, exper; save_meta=true) + train_models!(outcome, exper; save_models=save_output, save_meta=true) return outcome else - train_models!(outcome, exper) + train_models!(outcome, exper; save_models=save_output) end # Benchmark: diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl index 090acadaf077a8677967be9b552b1b64d1f25b7c..e65f8ae8d1f2af3dcb82eb6570710d7341d2b8ec 100644 --- a/experiments/grid_search.jl +++ b/experiments/grid_search.jl @@ -44,5 +44,7 @@ function grid_search( end # Save: - Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes) + if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes) + end end \ No newline at end of file diff --git a/experiments/jobscripts/tuning/generators/synthetic.sh b/experiments/jobscripts/tuning/generators/synthetic.sh index 83744419ae5d38aa0d396bf1b73859f5c34bb221..4a4e857d3177bf30f67f0b26356e56ed59cc94f4 100644 --- a/experiments/jobscripts/tuning/generators/synthetic.sh +++ b/experiments/jobscripts/tuning/generators/synthetic.sh @@ -1,7 +1,7 @@ #!/bin/bash #SBATCH --job-name="Grid-search Synthetic (ECCCo)" -#SBATCH --time=06:00:00 +#SBATCH --time=02:00:00 #SBATCH --ntasks=1000 #SBATCH --cpus-per-task=1 #SBATCH --partition=compute diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 0ce02068b8a537edc083f445d9b350cd54fa640b..1b0093b2adab02f378c56456c007db2e61886445 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -2,7 +2,7 @@ include("additional_models.jl") include("default_models.jl") include("train_models.jl") -function prepare_models(exper::Experiment) +function prepare_models(exper::Experiment; save_models::Bool=true) # Unpack data: X, labels, sampler = prepare_data(exper::Experiment) @@ -75,7 +75,7 @@ function prepare_models(exper::Experiment) end # Save models: - if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) @info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))." Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict) end