From 8bd0aa848d755efc4b31b8d57155106229d5dcf0 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 19 Sep 2023 10:44:48 +0200
Subject: [PATCH] downgraded cali spec a little

---
 experiments/california_housing.jl | 15 ++++++++++++++-
 experiments/models/models.jl      |  4 +++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/experiments/california_housing.jl b/experiments/california_housing.jl
index 02c7ba85..57e5068b 100644
--- a/experiments/california_housing.jl
+++ b/experiments/california_housing.jl
@@ -6,7 +6,20 @@ counterfactual_data, test_data = train_test_split(load_california_housing(nothin
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
 # Tuning parameters:
-tuning_params = DEFAULT_GENERATOR_TUNING
+tuning_params = (
+    nsamples=[10, 30],
+    niter_eccco=[10, 30],
+    Λ=[
+        [0.1, 0.1, 0.1],
+        [0.1, 0.2, 0.2],
+        [0.1, 0.5, 0.5],
+    ],
+    reg_strength=[0.0, 0.1, 0.5],
+    opt=[
+        Flux.Optimise.Descent(0.1),
+        Flux.Optimise.Descent(0.01),
+    ],
+)
 
 # Parameter choices:
 params = (
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index d063ca77..c446b524 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -75,7 +75,9 @@ function prepare_models(exper::Experiment; save_models::Bool=true)
     end
 
     # Save models:
-    if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+    local_models_exist = isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls"))
+    on_root_process = !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+    if save_models && on_root_process && !local_models_exist
         @info "Saving models to $(joinpath(exper.output_path , "$(exper.save_name)_models.jls"))."
         Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict)
     end
-- 
GitLab