From 025485a6033715190ce0b328c3108faada5c34c2 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Fri, 15 Sep 2023 15:26:32 +0200
Subject: [PATCH] model tuning

---
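With `tune_model` passed on the command line, `run_experiment` now tunes the MLP via
`tune_model(exper)` and returns the tuned machine instead of running the benchmark; the
tuned machine is serialized under DEFAULT_OUTPUT_PATH/tuned_model. The tunable builder
lives in default_models.jl and the shared tuning grids in setup_env.jl. A rough sketch of
how the grid-search helper can be called (illustrative only, not part of the patch: MLJ's
`make_moons` stands in for the experiments' own data loading, and `default_builder`,
`tune_model` and `DEFAULT_MODEL_TUNING_SMALL` are assumed to be in scope):

    using MLJ, MLJFlux, Flux

    X, y = make_moons(200)                    # toy stand-in for the experiment data
    mod = NeuralNetworkClassifier(
        builder=default_builder(),            # tunable MLP from default_models.jl
        epochs=100,
        batch_size=10,
    )
    mach = tune_model(mod, X, y; tuning_params=DEFAULT_MODEL_TUNING_SMALL)
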
 experiments/Manifest.toml            |  8 +++---
 experiments/circles.jl               | 13 +++-------
 experiments/experiment.jl            | 12 ++++++---
 experiments/linearly_separable.jl    | 13 +++-------
 experiments/model_tuning.jl          | 39 ++++++++++++++++++----------
 experiments/models/default_models.jl | 22 ++++++++++++----
 experiments/moons.jl                 | 13 +++-------
 experiments/setup_env.jl             | 31 +++++++++++++++++++++-
 8 files changed, 98 insertions(+), 53 deletions(-)

diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml
index adaf8cc9..f6d05355 100644
--- a/experiments/Manifest.toml
+++ b/experiments/Manifest.toml
@@ -805,9 +805,9 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"
 
 [[deps.GLM]]
 deps = ["Distributions", "LinearAlgebra", "Printf", "Reexport", "SparseArrays", "SpecialFunctions", "Statistics", "StatsAPI", "StatsBase", "StatsFuns", "StatsModels"]
-git-tree-sha1 = "97829cfda0df99ddaeaafb5b370d6cab87b7013e"
+git-tree-sha1 = "273bd1cd30768a2fddfa3fd63bbc746ed7249e5f"
 uuid = "38e38edf-8417-5370-95a0-9cbb8c7f171a"
-version = "1.8.3"
+version = "1.9.0"
 
 [[deps.GPUArrays]]
 deps = ["Adapt", "GPUArraysCore", "LLVM", "LinearAlgebra", "Printf", "Random", "Reexport", "Serialization", "Statistics"]
@@ -2389,9 +2389,9 @@ uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
 version = "0.3.5"
 
 [[deps.TupleTools]]
-git-tree-sha1 = "c8cdc29448afa1a306419f5d1c7af0854c171c80"
+git-tree-sha1 = "155515ed4c4236db30049ac1495e2969cc06be9d"
 uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
-version = "1.4.1"
+version = "1.4.3"
 
 [[deps.URIs]]
 git-tree-sha1 = "b7a5e99f24892b6824a954199a45e9ffcc1c70f0"
diff --git a/experiments/circles.jl b/experiments/circles.jl
index 52b8870b..1bfe1ad5 100644
--- a/experiments/circles.jl
+++ b/experiments/circles.jl
@@ -3,16 +3,11 @@ dataname = "Circles"
 n_obs = Int(1000 / (1.0 - TEST_SIZE))
 counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE)
 
+# Model tuning:
+model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
+
 # Tuning parameters:
-tuning_params = (
-    nsamples=[10, 50, 100],
-    niter_eccco=[20, 50, 100],
-    Λ=[
-        [0.1, 0.1, 0.1],
-        [0.1, 0.2, 0.2],
-        [0.1, 0.5, 0.5],
-    ]
-)
+tuning_params = DEFAULT_GENERATOR_TUNING
 
 # Parameter choices:
 params = (
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 17247504..5d31d7f5 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -38,6 +38,7 @@ Base.@kwdef struct Experiment
     train_parallel::Bool = false
     reg_strength::Real = 0.1
     niter_eccco::Union{Nothing,Int} = nothing
+    model_tuning_params::Tuple = DEFAULT_MODEL_TUNING_SMALL
 end
 
 "A container to hold the results of an experiment."
@@ -88,13 +89,18 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models::
     end
     outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
 
-    # Models
-    train_models!(outcome, exper)
+    # Model tuning:
+    if TUNE_MODEL
+        mach = tune_model(exper)
+        return mach
+    end
 
+    # Model training:
+    train_models!(outcome, exper)
     # Return if only models are needed:
     !only_models || return outcome
 
-    # Benchmark
+    # Benchmark:
     benchmark!(outcome, exper)
     if is_multi_processed(exper)
         MPI.Barrier(exper.parallelizer.comm)
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index f159a486..0c98beb0 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -6,16 +6,11 @@ counterfactual_data, test_data = train_test_split(
     test_size=TEST_SIZE
 )
 
+# Model tuning:
+model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
+
 # Tuning parameters:
-tuning_params = (
-    nsamples=[10, 50, 100],
-    niter_eccco=[20, 50, 100],
-    Λ=[
-        [0.1, 0.1, 0.1],
-        [0.1, 0.2, 0.2],
-        [0.1, 0.5, 0.5],
-    ]
-)
+tuning_params = DEFAULT_GENERATOR_TUNING
 
 # Parameter choices:
 params = (
diff --git a/experiments/model_tuning.jl b/experiments/model_tuning.jl
index d9b02379..64157c40 100644
--- a/experiments/model_tuning.jl
+++ b/experiments/model_tuning.jl
@@ -1,15 +1,26 @@
-"An MLP builder that is more easily tunable."
-mutable struct TuningBuilder <: MLJFlux.Builder
-    n_hidden::Int
-    n_layers::Int
-end
-
-"Outer constructor."
-TuningBuilder(;n_hidden=32, n_layers=3) = TuningBuilder(n_hidden, n_layers)
+"""
+    tune_model(exper::Experiment; kwargs...)
 
-function MLJFlux.build(nn::TuningBuilder, rng, n_in, n_out)
-    hidden = ntuple(i -> nn.n_hidden, nn.n_layers)
-    return MLJFlux.build(MLJFlux.MLP(hidden=hidden), rng, n_in, n_out)
+Tunes the MLP for the given experiment and saves the tuned model to disk.
+"""
+function tune_model(exper::Experiment; kwargs...)
+    if !(is_multi_processed(exper) && MPI.Comm_rank(exper.parallelizer.comm) != 0)
+        @info "Tuning models."
+        # Output path:
+        model_tuning_path = mkpath(joinpath(DEFAULT_OUTPUT_PATH, "tuned_model"))
+        # Simple MLP:
+        mod = NeuralNetworkClassifier(
+            builder=default_builder(),
+            epochs=exper.epochs,
+            batch_size=batch_size(exper),
+            finaliser=exper.finaliser,
+            loss=exper.loss,
+            acceleration=CUDALibs(),
+        )
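+        # NOTE: `X` and `y` are assumed to be the training features and targets
+        # prepared by the experiment scripts; they are not defined in this function.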
+        mach = tune_model(mod, X, y; tuning_params=exper.model_tuning_params, measure=exper.model_measures, kwargs...)
+        Serialization.serialize(joinpath(model_tuning_path, "$(exper.save_name).jls"), mach)
+        return mach
+    end
 end
 
 """
@@ -17,7 +28,7 @@ end
 
 Tunes a model by performing a grid search over the parameters specified in `tuning_params`.
 """
-function tune_model(mod::Supervised, X, y; tuning_params::NamedTuple, kwargs...)
+function tune_model(mod::Supervised, X, y; tuning_params::NamedTuple, measure=MODEL_MEASURES, kwargs...)
 
     ranges = []
 
@@ -36,6 +47,7 @@ function tune_model(mod::Supervised, X, y; tuning_params::NamedTuple, kwargs...)
     self_tuning_mod = TunedModel(
         model=mod,
         range=ranges,
+        measure=measure,
         kwargs...
     )
 
@@ -44,4 +56,5 @@ function tune_model(mod::Supervised, X, y; tuning_params::NamedTuple, kwargs...)
 
     return mach
 
-end
\ No newline at end of file
+end
+
diff --git a/experiments/models/default_models.jl b/experiments/models/default_models.jl
index 5393813e..ae009817 100644
--- a/experiments/models/default_models.jl
+++ b/experiments/models/default_models.jl
@@ -1,13 +1,25 @@
+"An MLP builder that is more easily tunable."
+mutable struct TuningBuilder <: MLJFlux.Builder
+    n_hidden::Int
+    n_layers::Int
+    activation::Function
+end
+
+"Outer constructor."
+TuningBuilder(; n_hidden=32, n_layers=3, activation=Flux.swish) = TuningBuilder(n_hidden, n_layers, activation)
+
+function MLJFlux.build(nn::TuningBuilder, rng, n_in, n_out)
+    hidden = ntuple(i -> nn.n_hidden, nn.n_layers)
+    return MLJFlux.build(MLJFlux.MLP(hidden=hidden, σ=nn.activation), rng, n_in, n_out)
+end
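+# For illustration: TuningBuilder(n_hidden=32, n_layers=3) builds the same network as
+# MLJFlux.MLP(hidden=(32, 32, 32), σ=Flux.swish), but exposes `n_hidden`, `n_layers`
+# and `activation` as plain fields that the grid search in model_tuning.jl can range over.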
+
 """
-    default_builder(n_hidden::Int=16, activation::Function=Flux.swish)
+    default_builder(n_hidden::Int=16, n_layers::Int=3, activation::Function=Flux.swish)
 
 Default builder for MLPs.
 """
-function default_builder(n_hidden::Int=16, activation::Function=Flux.swish)
-    builder = MLJFlux.MLP(
-        hidden=(n_hidden, n_hidden, n_hidden),
-        σ=activation
-    )
+function default_builder(n_hidden::Int=16, n_layers::Int=3, activation::Function=Flux.swish)
+    builder = TuningBuilder(n_hidden=n_hidden, n_layers=n_layers, activation=activation)
     return builder
 end
 
diff --git a/experiments/moons.jl b/experiments/moons.jl
index 99306add..ef700ac8 100644
--- a/experiments/moons.jl
+++ b/experiments/moons.jl
@@ -3,16 +3,11 @@ dataname = "Moons"
 n_obs = Int(2500 / (1.0 - TEST_SIZE))
 counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=TEST_SIZE)
 
+# Model tuning:
+model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
+
 # Tuning parameters:
-tuning_params = (
-    nsamples=[10, 50, 100],
-    niter_eccco=[20, 50, 100],
-    Λ=[
-        [0.1, 0.1, 0.1],
-        [0.1, 0.2, 0.2],
-        [0.1, 0.5, 0.5],
-    ]
-)
+tuning_params = DEFAULT_GENERATOR_TUNING
 
 # Parameter choices:
 params = (
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 5eb01a4f..21af5f67 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -141,4 +141,33 @@ const N_IND = n_individuals
 const N_IND_SPECIFIED = n_ind_specified
 
 "Boolean flag to check if grid search was specified."
-const GRID_SEARCH = "grid_search" ∈ ARGS
\ No newline at end of file
+const GRID_SEARCH = "grid_search" ∈ ARGS
+
+"Generator tuning parameters."
+const DEFAULT_GENERATOR_TUNING = (
+    nsamples=[10, 100],
+    niter_eccco=[10, 100],
+    Λ=[
+        [0.1, 0.1, 0.1],
+        [0.1, 0.2, 0.2],
+        [0.1, 0.5, 0.5],
+    ],
+    reg_strength=[0.0, 0.1, 0.5],
+)
+
+"Boolean flag to check if model tuning was specified."
+const TUNE_MODEL = "tune_model" ∈ ARGS
+
+"Model tuning parameters for small datasets."
+const DEFAULT_MODEL_TUNING_SMALL = (
+    n_hidden=[16, 32, 64],
+    n_layers=[1, 2, 3],
+    activation=[Flux.relu, Flux.swish],
+)
+
+"Model tuning parameters for large datasets."
+const DEFAULT_MODEL_TUNING_LARGE = (
+    n_hidden=[32, 64, 128],
+    n_layers=[2, 3, 5],
+    activation=[Flux.relu, Flux.swish],
+)
-- 
GitLab