From d312a6bfe5e753a17f276b9a7c8a5fc1cc9c7da5 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Sun, 17 Sep 2023 20:44:09 +0200
Subject: [PATCH] some stuff

---
 experiments/Manifest.toml                             | 6 ++++--
 experiments/experiment.jl                             | 8 ++++----
 experiments/grid_search.jl                            | 4 +++-
 experiments/jobscripts/tuning/generators/synthetic.sh | 2 +-
 experiments/models/models.jl                          | 4 ++--
 5 files changed, 14 insertions(+), 10 deletions(-)

diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml
index 244c5a2f..a82fd00b 100644
--- a/experiments/Manifest.toml
+++ b/experiments/Manifest.toml
@@ -444,9 +444,11 @@ version = "0.6.2"
 
 [[deps.CounterfactualExplanations]]
 deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
-git-tree-sha1 = "04acac995c56d1609277848e3bdd035e89b712c2"
+git-tree-sha1 = "108bf1c3a0cefdcc53bda5482ea1aa6d8ee86382"
+repo-rev = "main"
+repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
-version = "0.1.21"
+version = "0.1.23"
 
     [deps.CounterfactualExplanations.extensions]
     MPIExt = "MPI"
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index baf73b49..74755b7b 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -54,8 +54,8 @@ end
 
 Train the models specified by `exper` and store them in `outcome`.
 """
-function train_models!(outcome::ExperimentOutcome, exper::Experiment; save_meta::Bool=false)
-    model_dict = prepare_models(exper)
+function train_models!(outcome::ExperimentOutcome, exper::Experiment; save_models::Bool=true, save_meta::Bool=false)
+    model_dict = prepare_models(exper; save_models=save_models)
     outcome.model_dict = model_dict
     if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
         meta_model_performance(outcome; save_output=save_meta)
@@ -97,10 +97,10 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models::
 
     # Model training:
     if only_models
-        train_models!(outcome, exper; save_meta=true)
+        train_models!(outcome, exper; save_models=save_output, save_meta=true)
         return outcome
     else
-        train_models!(outcome, exper)
+        train_models!(outcome, exper; save_models=save_output)
     end
 
     # Benchmark:
diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
index 090acada..e65f8ae8 100644
--- a/experiments/grid_search.jl
+++ b/experiments/grid_search.jl
@@ -44,5 +44,7 @@ function grid_search(
     end
 
     # Save:
-    Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
+    if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+        Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
+    end
 end
\ No newline at end of file
diff --git a/experiments/jobscripts/tuning/generators/synthetic.sh b/experiments/jobscripts/tuning/generators/synthetic.sh
index 83744419..4a4e857d 100644
--- a/experiments/jobscripts/tuning/generators/synthetic.sh
+++ b/experiments/jobscripts/tuning/generators/synthetic.sh
@@ -1,7 +1,7 @@
 #!/bin/bash
 
 #SBATCH --job-name="Grid-search Synthetic (ECCCo)"
-#SBATCH --time=06:00:00
+#SBATCH --time=02:00:00
 #SBATCH --ntasks=1000
 #SBATCH --cpus-per-task=1
 #SBATCH --partition=compute
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 0ce02068..1b0093b2 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -2,7 +2,7 @@ include("additional_models.jl")
 include("default_models.jl")
 include("train_models.jl")
 
-function prepare_models(exper::Experiment)
+function prepare_models(exper::Experiment; save_models::Bool=true)
 
     # Unpack data:
     X, labels, sampler = prepare_data(exper::Experiment)
@@ -75,7 +75,7 @@ function prepare_models(exper::Experiment)
     end
 
     # Save models:
-    if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+    if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
         @info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))."
         Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict)
     end
-- 
GitLab