From 048e974ed5514f85efaa6b6c3a77d4e8e40fbc29 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 5 Sep 2023 15:09:22 +0200
Subject: [PATCH] Extract default builders, pass epochs through, and fix additional_models handling

---
 experiments/gmsc.jl          | 18 +++++++++++-------
 experiments/mnist.jl         | 17 ++++++++++-------
 experiments/models/models.jl |  7 +++++--
 experiments/setup_env.jl     |  2 +-
 4 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index f8ecaa4d..f9c8120e 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -1,14 +1,18 @@
 counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
+
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
 run_experiment(
     counterfactual_data, test_data; 
     dataname="GMSC",
-    n_hidden=128,
-    activation = Flux.swish,
-    builder = MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     α=[1.0, 1.0, 1e-1],
     sampling_batch_size=10,
     sampling_steps = 30,
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 21cb7b79..372552e2 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -22,16 +22,19 @@ add_models = Dict(
     :lenet5 => lenet5,
 )
 
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
 # Run:
 run_experiment(
     counterfactual_data, test_data; 
     dataname="MNIST",
-    n_hidden = 128,
-    activation = Flux.swish,
-    builder= MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     𝒟x = Uniform(-1.0, 1.0),
     α = [1.0,1.0,1e-2],
     sampling_batch_size = 10,
@@ -42,6 +45,6 @@ run_experiment(
     nmin = 10,
     use_variants = false,
     use_class_loss = true,
-    add_models = add_models,
+    additional_models = add_models,
     epochs = 10,
 )
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 89975bcd..bd2d5026 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -11,6 +11,8 @@ function prepare_models(exp::Experiment)
     if !exp.use_pretrained
         if isnothing(exp.builder)
             builder = default_builder()
+        else
+            builder = exp.builder
         end
         # Default models:
         if isnothing(exp.models)
@@ -25,6 +27,7 @@ function prepare_models(exp::Experiment)
                 use_ensembling=exp.use_ensembling,
                 finaliser=exp.finaliser,
                 loss=exp.loss,
+                epochs=exp.epochs,
             )
         end
         # Additional models:
@@ -36,10 +39,10 @@ function prepare_models(exp::Experiment)
                     batch_size=batch_size(exp),
                     finaliser=exp.finaliser,
                     loss=exp.loss,
+                    epochs=exp.epochs,
                 )
             end
-            add_models = Dict(k => mod(;batch_size=batch_size(exp), ) for (k, mod) in exp.additional_models)
-            models = merge!(models, exp.additional_models)
+            models = merge(models, add_models)
         end
         @info "Training models."
         model_dict = train_models(models, X, labels; cov=exp.coverage)
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 1d6a3be3..c6657bdd 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -12,7 +12,7 @@ using CounterfactualExplanations.Parallelization
 using CSV
 using Dates
 using DataFrames
-using Distributions: Normal, Distribution, Categorical
+using Distributions: Normal, Distribution, Categorical, Uniform
 using ECCCo
 using Flux
 using JointEnergyModels
-- 
GitLab
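
Editor's note: below is a minimal, self-contained sketch (not the repository's code) of the builder-fallback pattern this patch introduces in prepare_models and of the default MLP builder now extracted to the top of gmsc.jl and mnist.jl. The names ToyExperiment, default_builder, and choose_builder are hypothetical stand-ins for the repository's Experiment struct and helpers.

# Sketch of the builder fallback, assuming an experiment carries an optional builder.
using Flux
using MLJFlux

struct ToyExperiment
    builder    # `nothing` or an MLJFlux builder
end

# Default MLP builder, analogous to the blocks now defined in the experiment scripts.
# Inside `MLJFlux.@builder`, `n_in` and `n_out` are supplied by MLJFlux at fit time.
function default_builder(; n_hidden=128, activation=Flux.swish)
    return MLJFlux.@builder Flux.Chain(
        Dense(n_in, n_hidden, activation),
        Dense(n_hidden, n_out),
    )
end

function choose_builder(exp::ToyExperiment)
    # The added `else` branch means a user-supplied builder is no longer ignored:
    if isnothing(exp.builder)
        builder = default_builder()
    else
        builder = exp.builder
    end
    return builder
end

choose_builder(ToyExperiment(nothing))    # falls back to the default MLP builder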