From 8bd0aa848d755efc4b31b8d57155106229d5dcf0 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Tue, 19 Sep 2023 10:44:48 +0200 Subject: [PATCH] downgraded cali spec a little --- experiments/california_housing.jl | 15 ++++++++++++++- experiments/models/models.jl | 4 +++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/experiments/california_housing.jl b/experiments/california_housing.jl index 02c7ba85..57e5068b 100644 --- a/experiments/california_housing.jl +++ b/experiments/california_housing.jl @@ -6,7 +6,20 @@ counterfactual_data, test_data = train_test_split(load_california_housing(nothin model_tuning_params = DEFAULT_MODEL_TUNING_LARGE # Tuning parameters: -tuning_params = DEFAULT_GENERATOR_TUNING +tuning_params = ( + nsamples=[10, 30], + niter_eccco=[10, 30], + Λ=[ + [0.1, 0.1, 0.1], + [0.1, 0.2, 0.2], + [0.1, 0.5, 0.5], + ], + reg_strength=[0.0, 0.1, 0.5], + opt=[ + Flux.Optimise.Descent(0.1), + Flux.Optimise.Descent(0.01), + ], +) # Parameter choices: params = ( diff --git a/experiments/models/models.jl b/experiments/models/models.jl index d063ca77..c446b524 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -75,7 +75,9 @@ function prepare_models(exper::Experiment; save_models::Bool=true) end # Save models: - if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + local_models_exist = isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) + on_root_process = !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + if save_models && on_root_process && !local_models_exist @info "Saving models to $(joinpath(exper.output_path , "$(exper.save_name)_models.jls"))." Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict) end -- GitLab