From 232965b664b356016c59454f2d272107df974fc9 Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Fri, 15 Sep 2023 10:34:28 +0200
Subject: [PATCH] set up for grid search

---
 experiments/experiment.jl                  | 14 ++++---
 experiments/grid_search.jl                 | 43 ++++++++++++++++++++++
 experiments/jobscripts/tuning/synthetic.sh | 14 +++++++
 experiments/linearly_separable.jl          | 34 +++++++++++++++--
 experiments/setup_env.jl                   |  7 +++-
 5 files changed, 101 insertions(+), 11 deletions(-)
 create mode 100644 experiments/grid_search.jl
 create mode 100644 experiments/jobscripts/tuning/synthetic.sh

diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index b161df03..17247504 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -80,10 +80,12 @@ Run the experiment specified by `exper`.
 function run_experiment(exper::Experiment; save_output::Bool=true, only_models::Bool=ONLY_MODELS)
     
     # Setup
-    @info "All results will be saved to $(exper.output_path)."
-    isdir(exper.output_path) || mkdir(exper.output_path)
-    @info "All parameter choices will be saved to $(exper.params_path)."
-    isdir(exper.params_path) || mkdir(exper.params_path)
+    if save_output && !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+        @info "All results will be saved to $(exper.output_path)."
+        isdir(exper.output_path) || mkdir(exper.output_path)
+        @info "All parameter choices will be saved to $(exper.params_path)."
+        isdir(exper.params_path) || mkdir(exper.params_path)
+    end
     outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
 
     # Models
@@ -126,8 +128,8 @@ end
 
 # Pre-trained models:
 function pretrained_path(exper::Experiment)
-    if isfile(joinpath(exper.output_path, "$(exper.save_name)_models.jls"))
-        @info "Found local pre-trained models in $(exper.output_path) and using those."
+    if isfile(joinpath(DEFAULT_OUTPUT_PATH, "$(exper.save_name)_models.jls"))
+        @info "Found local pre-trained models in $(DEFAULT_OUTPUT_PATH) and using those."
         return exper.output_path
     else
         @info "Using artifacts. Models were pre-trained on `julia-$(LATEST_VERSION)` and may not work on other versions."
diff --git a/experiments/grid_search.jl b/experiments/grid_search.jl
new file mode 100644
index 00000000..a667895d
--- /dev/null
+++ b/experiments/grid_search.jl
@@ -0,0 +1,43 @@
+"""
+    grid_search(
+        counterfactual_data::CounterfactualData,
+        test_data::CounterfactualData;
+        dataname::String,
+        tuning_params::NamedTuple,
+        kwargs...,
+    )
+
+Perform a grid search over the hyperparameters specified by `tuning_params`. One experiment is run through `run_experiment` for each combination in the Cartesian product of the candidate values. Other keyword arguments are passed to `run_experiment` and fixed for all experiments. The outcomes, keyed by parameter combination, are serialized to a `.jls` file named after `dataname` inside the `grid_search` subfolder of `DEFAULT_OUTPUT_PATH`.
+"""
+function grid_search(
+    counterfactual_data::CounterfactualData,
+    test_data::CounterfactualData;
+    dataname::String,
+    tuning_params::NamedTuple,
+    kwargs...,
+)
+    # Output path (created if it does not exist yet):
+    grid_search_path = mkpath(joinpath(DEFAULT_OUTPUT_PATH, "grid_search"))
+
+    # Grid setup: one collection of `name => value` pairs per hyperparameter.
+    param_options = [Pair.(k, vals) for (k, vals) in pairs(tuning_params)]
+    grid = Iterators.product(param_options...)
+    outcomes = Dict{Any,Any}()
+
+    # Search: each grid element is a tuple of pairs that is splatted as keyword arguments.
+    for param_combo in grid
+        outcome = run_experiment(
+            counterfactual_data, test_data;
+            dataname=dataname,
+            output_path=grid_search_path,
+            save_output=false,
+            param_combo...,
+            kwargs...,
+        )
+        outcomes[param_combo] = outcome
+    end
+
+    # Save: file name is the lowercased `dataname` with spaces replaced by underscores.
+    save_name = replace(lowercase(dataname), " " => "_")
+    Serialization.serialize(joinpath(grid_search_path, "$(save_name).jls"), outcomes)
+end
\ No newline at end of file
diff --git a/experiments/jobscripts/tuning/synthetic.sh b/experiments/jobscripts/tuning/synthetic.sh
new file mode 100644
index 00000000..10246f04
--- /dev/null
+++ b/experiments/jobscripts/tuning/synthetic.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+#SBATCH --job-name="Grid-search Synthetic (ECCCo)"
+#SBATCH --time=03:00:00
+#SBATCH --ntasks=1000
+#SBATCH --cpus-per-task=1
+#SBATCH --partition=compute
+#SBATCH --mem-per-cpu=4GB
+#SBATCH --account=research-eemcs-insy
+#SBATCH --mail-type=END     # Set mail type to 'END' to receive a mail when the job finishes. 
+
+module load 2023r1 openmpi
+
+srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results mpi grid_search > experiments/synthetic.log
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index 59528732..418dce89 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -1,12 +1,38 @@
+# Data:
 n_obs = Int(1000 / (1.0 - TEST_SIZE))
 counterfactual_data, test_data = train_test_split(
     load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
     test_size=TEST_SIZE
 )
-run_experiment(
-    counterfactual_data, test_data; 
-    dataname="Linearly Separable",
+
+# Tuning parameters:
+tuning_params = (
+    nsamples=[10, 50, 100],
+    niter_eccco=[20, 50, 100],
+    Λ=[
+        [0.1, 0.1, 0.1],
+        [0.1, 0.2, 0.2],
+        [0.1, 0.5, 0.5],
+    ]
+)
+
+# Parameter choices:
+params = (
     nsamples=100,
     niter_eccco=100,
     Λ=[0.1, 0.2, 0.2],
-)
\ No newline at end of file
+)
+
+if !GRID_SEARCH 
+    run_experiment(
+        counterfactual_data, test_data; 
+        dataname="Linearly Separable",
+        params...
+    )
+else
+    grid_search(
+        counterfactual_data, test_data;
+        dataname="Linearly Separable",
+        tuning_params=tuning_params,
+    )
+end
\ No newline at end of file
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 0e60c97c..18a44381 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -34,6 +34,7 @@ ENV["DATADEPS_ALWAYS_ACCEPT"] = "true"              # avoid command prompt and j
 
 # Scripts:
 include("experiment.jl")
+include("grid_search.jl")
 include("data/data.jl")
 include("models/models.jl")
 include("benchmarking/benchmarking.jl")
@@ -121,6 +122,7 @@ const CE_MEASURES = [
 "Test set proportion."
 const TEST_SIZE = 0.2
 
+"Boolean flag to check if upload was specified."
 const UPLOAD = "upload" ∈ ARGS
 
 n_ind_specified = false
@@ -135,4 +137,7 @@ end
 const N_IND = n_individuals
 
 "Boolean flag to check if number of individuals was specified."
-const N_IND_SPECIFIED = n_ind_specified
\ No newline at end of file
+const N_IND_SPECIFIED = n_ind_specified
+
+"Boolean flag to check if grid search was specified."
+const GRID_SEARCH = "grid_search" ∈ ARGS
\ No newline at end of file
-- 
GitLab