From 4577b4a49f5ad2ba29414e7bbc248169a4ced7ac Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Mon, 18 Sep 2023 07:59:14 +0200
Subject: [PATCH] Fix grid-search result saving, generator tuning parameters and model output paths

---
 experiments/Manifest.toml    | 6 ++----
 experiments/fmnist.jl        | 7 ++++---
 experiments/grid_search.jl   | 7 ++++++-
 experiments/mnist.jl         | 7 ++++---
 experiments/models/models.jl | 4 ++--
 experiments/setup_env.jl     | 6 +++---
 6 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml
index 286844d4..1277de6e 100644
--- a/experiments/Manifest.toml
+++ b/experiments/Manifest.toml
@@ -444,11 +444,9 @@ version = "0.6.2"
 
 [[deps.CounterfactualExplanations]]
 deps = ["CSV", "CUDA", "CategoricalArrays", "ChainRulesCore", "DataFrames", "DecisionTree", "Distributions", "EvoTrees", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "Logging", "MLDatasets", "MLJBase", "MLJDecisionTreeInterface", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "PackageExtensionCompat", "Parameters", "PrecompileTools", "ProgressMeter", "Random", "Serialization", "Statistics", "StatsBase", "Tables", "UUIDs", "cuDNN"]
-git-tree-sha1 = "9bcb579703041d8708b179e55c119f150c5565bc"
-repo-rev = "main"
-repo-url = "https://github.com/JuliaTrustworthyAI/CounterfactualExplanations.jl.git"
+git-tree-sha1 = "30cf711962736a6bc5ffc6c7d1b6be6d11d306d9"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
-version = "0.1.23"
+version = "0.1.24"
 
     [deps.CounterfactualExplanations.extensions]
     MPIExt = "MPI"
diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl
index d911473b..f750b1c6 100644
--- a/experiments/fmnist.jl
+++ b/experiments/fmnist.jl
@@ -16,8 +16,8 @@ test_data = load_fashion_mnist_test()
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
 # Tuning parameters:
-tuning_params = DEFAULT_GENERATOR_TUNING[2:end]
-push!(tuning_params.Λ, [0.1, 0.1, 3.0])
+tuning_params = DEFAULT_GENERATOR_TUNING
+tuning_params = (; tuning_params..., Λ=[tuning_params.Λ[2:end]..., [0.1, 0.1, 3.0]])
 
 # Additional models:
 add_models = Dict(
@@ -39,7 +39,8 @@ params = (
     epochs=10,
     nsamples=10,
     nmin=1,
-    niter_eccco=100
+    niter_eccco=100,
+    Λ=[0.1, 0.1, 3.0]
 )
 
 if !GRID_SEARCH
diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
index e65f8ae8..bb77e6a3 100644
--- a/experiments/grid_search.jl
+++ b/experiments/grid_search.jl
@@ -25,6 +25,11 @@ function grid_search(
     tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
     grid = Iterators.product(tuning_params...)
     outcomes = Dict{Any,Any}()
+
+    # Save an initial (empty) outcomes file before the search starts:
+    if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
+        Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
+    end
     
     # Search:
     counter = 1
@@ -44,7 +49,7 @@ function grid_search(
     end
 
     # Save:
-    if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+    if !(is_multi_processed(PLZ) && MPI.Comm_rank(PLZ.comm) != 0)
         Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes)
     end
 end
\ No newline at end of file
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 388af47d..3a22eff4 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -16,8 +16,8 @@ test_data = load_mnist_test()
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
 # Tuning parameters:
-tuning_params = DEFAULT_GENERATOR_TUNING[2:end]
-push!(tuning_params.Λ, [0.1, 0.1, 3.0])
+tuning_params = DEFAULT_GENERATOR_TUNING
+tuning_params = (; tuning_params..., Λ=[tuning_params.Λ[2:end]..., [0.1, 0.1, 3.0]])
 
 # Additional models:
 add_models = Dict(
@@ -39,7 +39,8 @@ params = (
     epochs=10,
     nsamples=10,
     nmin=1,
-    niter_eccco=100
+    niter_eccco=100,
+    Λ=[0.1, 0.1, 3.0]
 )
 
 if !GRID_SEARCH
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 1b0093b2..597205db 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -76,8 +76,8 @@ function prepare_models(exper::Experiment; save_models::Bool=true)
 
     # Save models:
     if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
-        @info "Saving models to $(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"))."
-        Serialization.serialize(joinpath(pretrained_path(exper), "$(exper.save_name)_models.jls"), model_dict)
+        @info "Saving models to $(joinpath(exper.output_path , "$(exper.save_name)_models.jls"))."
+        Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict)
     end
 
     return model_dict
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index d8db488d..c6c4276e 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -145,7 +145,7 @@ const N_IND_SPECIFIED = n_ind_specified
 const GRID_SEARCH = "grid_search" ∈ ARGS
 
 "Generator tuning parameters."
-const DEFAULT_GENERATOR_TUNING = (
+DEFAULT_GENERATOR_TUNING = (
     nsamples=[10, 100],
     niter_eccco=[10, 100],
     Λ=[
@@ -160,13 +160,13 @@ const DEFAULT_GENERATOR_TUNING = (
 const TUNE_MODEL = "tune_model" ∈ ARGS
 
 "Model tuning parameters for small datasets."
-const DEFAULT_MODEL_TUNING_SMALL = (
+DEFAULT_MODEL_TUNING_SMALL = (
     n_hidden=[16, 32, 64],
     n_layers=[1, 2, 3],
 )
 
 "Model tuning parameters for large datasets."
-const DEFAULT_MODEL_TUNING_LARGE = (
+DEFAULT_MODEL_TUNING_LARGE = (
     n_hidden=[32, 64, 128, 512],
     n_layers=[2, 3, 5],
 )
-- 
GitLab
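
Note on the tuning-parameter change in fmnist.jl and mnist.jl: the old code mutated the shared default via push!, while the new code rebuilds the NamedTuple and leaves the default untouched. The sketch below is illustrative only; the Λ values are made up and the real defaults live in setup_env.jl.

    # Stand-in for DEFAULT_GENERATOR_TUNING from setup_env.jl (values are illustrative).
    DEFAULT_GENERATOR_TUNING = (
        nsamples=[10, 100],
        niter_eccco=[10, 100],
        Λ=[
            [0.1, 0.1, 0.1],
            [0.1, 0.1, 0.5],
        ],
    )

    # Rebuild the NamedTuple instead of push!-ing into the shared default:
    # splat the existing fields, then override Λ with a fresh vector that drops
    # the first candidate and appends the new penalty triple.
    tuning_params = DEFAULT_GENERATOR_TUNING
    tuning_params = (; tuning_params..., Λ=[tuning_params.Λ[2:end]..., [0.1, 0.1, 3.0]])

    @assert length(DEFAULT_GENERATOR_TUNING.Λ) == 2             # default left untouched
    @assert tuning_params.Λ == [[0.1, 0.1, 0.5], [0.1, 0.1, 3.0]]

The same pattern is applied in both the fmnist.jl and mnist.jl hunks.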
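Note on the grid_search.jl hunks: the serialization is guarded so that only the root MPI rank writes the outcomes file. The helpers used there (is_multi_processed, PLZ) come from the experiment setup and are not shown in this patch; the minimal sketch below expresses the same write-from-rank-0 pattern with plain MPI.jl and the Serialization stdlib, assuming MPI.jl is installed.

    using MPI, Serialization

    MPI.Init()
    comm = MPI.COMM_WORLD

    outcomes = Dict{Any,Any}("example" => rand(3))  # stand-in for grid-search results

    # Every rank holds `outcomes`, but only rank 0 writes the file so that
    # parallel runs do not clobber each other.
    if MPI.Comm_rank(comm) == 0
        Serialization.serialize(joinpath(tempdir(), "outcomes.jls"), outcomes)
    end

    MPI.Finalize()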