From 232965b664b356016c59454f2d272107df974fc9 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Fri, 15 Sep 2023 10:34:28 +0200 Subject: [PATCH] set up for grid search --- experiments/experiment.jl | 14 ++++--- experiments/grid_search.jl | 43 ++++++++++++++++++++++ experiments/jobscripts/tuning/synthetic.sh | 14 +++++++ experiments/linearly_separable.jl | 34 +++++++++++++++-- experiments/setup_env.jl | 7 +++- 5 files changed, 101 insertions(+), 11 deletions(-) create mode 100644 experiments/grid_search.jl create mode 100644 experiments/jobscripts/tuning/synthetic.sh diff --git a/experiments/experiment.jl b/experiments/experiment.jl index b161df03..17247504 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -80,10 +80,12 @@ Run the experiment specified by `exper`. function run_experiment(exper::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS) # Setup - @info "All results will be saved to $(exper.output_path)." - isdir(exper.output_path) || mkdir(exper.output_path) - @info "All parameter choices will be saved to $(exper.params_path)." - isdir(exper.params_path) || mkdir(exper.params_path) + if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0) + @info "All results will be saved to $(exper.output_path)." + isdir(exper.output_path) || mkdir(exper.output_path) + @info "All parameter choices will be saved to $(exper.params_path)." + isdir(exper.params_path) || mkdir(exper.params_path) + end outcome = ExperimentOutcome(exper, nothing, nothing, nothing) # Models @@ -126,8 +128,8 @@ end # Pre-trained models: function pretrained_path(exper::Experiment) - if isfile(joinpath(exper.output_path, "$(exper.save_name)_models.jls")) - @info "Found local pre-trained models in $(exper.output_path) and using those." + if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls")) + @info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those." return exper.output_path else @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions." diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl new file mode 100644 index 00000000..a667895d --- /dev/null +++ b/experiments/grid_search.jl @@ -0,0 +1,43 @@ +""" + grid_search( + couterfactual_data::CounterfactualData, + test_data::CounerfactualData; + dataname::String, + tuning_params::NamedTuple, + kwargs..., + ) + +Perform a grid search over the hyperparameters specified by `tuning_params`. Experiments will be run for each combination of hyperparameters. Other keyword arguments are passed to `run_experiment` and fixed for all experiments. +""" +function grid_search( + couterfactual_data::CounterfactualData, + test_data::CounerfactualData; + dataname::String, + tuning_params::NamedTuple, + kwargs..., +) + + # Output path: + grid_search_path = mkpath(joinpath(DEFAULT_OUTPUT_PATH, "grid_search")) + + # Grid setup: + tuning_params = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)] + grid = Iterators.product(tuning_params...) + outcomes = Dict{Any,Any}() + + # Search: + for tuning_params in grid + outcome = run_experiment( + counterfactual_data, test_data; + dataname=dataname, + output_path=grid_search_path, + save_output=false, + tuning_params..., + kwargs..., + ) + outcomes[tuning_params] = outcome + end + + # Save: + Serialization.serialize(joinpath(grid_search_path, "$(replace(lowercase(dataname), " " => "_")).jls"), outcomes) +end \ No newline at end of file diff --git a/experiments/jobscripts/tuning/synthetic.sh b/experiments/jobscripts/tuning/synthetic.sh new file mode 100644 index 00000000..10246f04 --- /dev/null +++ b/experiments/jobscripts/tuning/synthetic.sh @@ -0,0 +1,14 @@ +#!/bin/bash + +#SBATCH --job-name="Grid-search Synthetic (ECCCo)" +#SBATCH --time=03:00:00 +#SBATCH --ntasks=1000 +#SBATCH --cpus-per-task=1 +#SBATCH --partition=compute +#SBATCH --mem-per-cpu=4GB +#SBATCH --account=research-eemcs-insy +#SBATCH --mail-type=END # Set mail type to 'END' to receive a mail when the job finishes. + +module load 2023r1 openmpi + +srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results mpi grid_search > experiments/synthetic.log diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index 59528732..418dce89 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -1,12 +1,38 @@ +# Data: n_obs = Int(1000 / (1.0 - TEST_SIZE)) counterfactual_data, test_data = train_test_split( load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0)); test_size=TEST_SIZE ) -run_experiment( - counterfactual_data, test_data; - dataname="Linearly Separable", + +# Tuning parameters: +tuning_params = ( + nsamples=[10, 50, 100], + niter_eccco=[20, 50, 100], + Λ=[ + [0.1, 0.1, 0.1], + [0.1, 0.2, 0.2], + [0.1, 0.5, 0.5], + ] +) + +# Parameter choices: +params = ( nsamples=100, niter_eccco=100, Λ=[0.1, 0.2, 0.2], -) \ No newline at end of file +) + +if !GRID_SEARCH + run_experiment( + counterfactual_data, test_data; + dataname="Linearly Separable", + params... + ) +else + grid_search( + counterfactual_data, test_data; + dataname="Linearly Separable", + tuning_params=tuning_params, + ) +end \ No newline at end of file diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 0e60c97c..18a44381 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -34,6 +34,7 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = "true" # avoid command prompt and j # Scripts: include("experiment.jl") +include("grid_search.jl") include("data/data.jl") include("models/models.jl") include("benchmarking/benchmarking.jl") @@ -121,6 +122,7 @@ const CE_MEASURES = [ "Test set proportion." const TEST_SIZE = 0.2 +"Boolean flag to check if upload was specified." const UPLOAD = "upload" ∈ ARGS n_ind_specified = false @@ -135,4 +137,7 @@ end const N_IND = n_individuals "Boolean flag to check if number of individuals was specified." -const N_IND_SPECIFIED = n_ind_specified \ No newline at end of file +const N_IND_SPECIFIED = n_ind_specified + +"Boolean flag to check if grid search was specified." +const GRID_SEARCH = "grid_search" ∈ ARGS \ No newline at end of file -- GitLab