diff --git a/experiments/california_housing.jl b/experiments/california_housing.jl index 02c7ba856511896208fa52f3d6418121aa71a9bf..57e5068b991f552f942bea8c735e11d6f2e0a581 100644 --- a/experiments/california_housing.jl +++ b/experiments/california_housing.jl @@ -6,7 +6,20 @@ counterfactual_data, test_data = train_test_split(load_california_housing(nothin model_tuning_params = DEFAULT_MODEL_TUNING_LARGE # Tuning parameters: -tuning_params = DEFAULT_GENERATOR_TUNING +tuning_params = ( + nsamples=[10, 30], + niter_eccco=[10, 30], + Λ=[ + [0.1, 0.1, 0.1], + [0.1, 0.2, 0.2], + [0.1, 0.5, 0.5], + ], + reg_strength=[0.0, 0.1, 0.5], + opt=[ + Flux.Optimise.Descent(0.1), + Flux.Optimise.Descent(0.01), + ], +) # Parameter choices: params = ( diff --git a/experiments/models/models.jl b/experiments/models/models.jl index d063ca77ad216469a5edd8289e5e1e2989ac424c..c446b524fe977aa45ec64022e19177c6558c700c 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -75,7 +75,9 @@ function prepare_models(exper::Experiment; save_models::Bool=true) end # Save models: - if save_models && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + local_models_exist = isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) + on_root_process = !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + if save_models && on_root_process && !local_models_exist @info "Saving models to $(joinpath(exper.output_path , "$(exper.save_name)_models.jls"))." Serialization.serialize(joinpath(exper.output_path, "$(exper.save_name)_models.jls"), model_dict) end