From 9221e51f98f685553c4991734f22853834cd1b8c Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Fri, 1 Sep 2023 13:36:09 +0200
Subject: [PATCH] Clean up experiments: consolidate per-lambda keywords into a single `Λ` array, move hyperparameter defaults into the `Experiment` struct, and fix `ssampling_steps` typo

---
 experiments/circles.jl            | 8 +++-----
 experiments/experiment.jl         | 9 ++++++++-
 experiments/gmsc.jl               | 6 +-----
 experiments/linearly_separable.jl | 5 ++++-
 experiments/mnist.jl              | 2 +-
 experiments/models/models.jl      | 3 ++-
 experiments/moons.jl              | 9 +++------
 7 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/experiments/circles.jl b/experiments/circles.jl
index 6cb3a38a..110c22f7 100644
--- a/experiments/circles.jl
+++ b/experiments/circles.jl
@@ -1,14 +1,12 @@
 n_obs = Int(1000 / (1.0 - TEST_SIZE))
 counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE)
 run_experiment(
-    counterfactual_data, test_data; dataname="Circles",
+    counterfactual_data, test_data; 
+    dataname="Circles",
     n_hidden=32,
     α=[1.0, 1.0, 1e-2],
     sampling_batch_size=nothing,
     sampling_steps=20,
-    λ₁=0.25,
-    λ₂ = 0.75,
-    λ₃ = 0.75,
+    Λ=[0.25, 0.75, 0.75],
     opt=Flux.Optimise.Descent(0.01),
-    use_class_loss = false,
 )
\ No newline at end of file
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 60df5148..469a123e 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -11,13 +11,20 @@ Base.@kwdef struct Experiment
     builder::Union{Nothing,MLJFlux.Builder} = nothing
     𝒟x::Distribution = Normal()
     sampling_batch_size::Int = 50
+    sampling_steps::Int = 50
     min_batch_size::Int = 128
+    epochs::Int = 100
+    n_hidden::Int = 32
+    activation::Function = Flux.relu
+    α::AbstractArray = [1.0, 1.0, 1e-1]
+    n_ens::Int = 5
+    use_ensembling::Bool = true
     coverage::Float64 = DEFAULT_COVERAGE
     generators::Union{Nothing,Dict} = nothing
     n_individuals::Int = 50
     ce_measures::AbstractArray = CE_MEASURES
     model_measures::Dict = MODEL_MEASURES
-    use_class_loss::Bool = true
+    use_class_loss::Bool = false
     use_variants::Bool = true
     Λ::AbstractArray = [0.25, 0.75, 0.75]
     Λ_Δ::AbstractArray = Λ
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index 8b53e93e..132e02f3 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -12,10 +12,6 @@ run_experiment(
     sampling_batch_size=nothing,
     sampling_steps = 30,
     use_ensembling = true,
-    λ₁ = 0.1,
-    λ₂ = 0.5,
-    λ₃ = 0.5,
+    Λ=[0.1, 0.5, 0.5],
     opt = Flux.Optimise.Descent(0.05),
-    use_class_loss=false,
-    use_variants=false,
 )
\ No newline at end of file
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index 4adc399a..4fcf9ea7 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -3,4 +3,7 @@ counterfactual_data, test_data = train_test_split(
     load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
     test_size=TEST_SIZE
 )
-run_experiment(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
+run_experiment(
+    counterfactual_data, test_data; 
+    dataname="Linearly Separable"
+)
\ No newline at end of file
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 39e13d40..557143e9 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -52,7 +52,7 @@ run_experiment(
     𝒟x = Uniform(-1.0, 1.0),
     α = [1.0,1.0,1e-2],
     sampling_batch_size = 10,
-    ssampling_steps=25,
+    sampling_steps=25,
     use_ensembling = true,
     generators = generator_dict,
 )
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 16bffaeb..472e1c16 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -17,7 +17,8 @@ function prepare_models(exp::Experiment)
             models = default_models(;
                 sampler=sampler,
                 builder=builder,
-                batch_size=batch_size(exp)
+                batch_size=batch_size(exp),
+                sampling_steps=exp.sampling_steps,
             )
         end
         @info "Training models."
diff --git a/experiments/moons.jl b/experiments/moons.jl
index bb660842..02034380 100644
--- a/experiments/moons.jl
+++ b/experiments/moons.jl
@@ -1,16 +1,13 @@
 n_obs = Int(2500 / (1.0 - TEST_SIZE))
 counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=TEST_SIZE)
 run_experiment(
-    counterfactual_data, test_data; dataname="Moons",
+    counterfactual_data, test_data; 
+    dataname="Moons",
     epochs=500,
     n_hidden=32,
     activation = Flux.relu,
-    α=[1.0, 1.0, 1e-1],
     sampling_batch_size=10,
     sampling_steps=30,
-    λ₁=0.25,
-    λ₂=0.75,
-    λ₃=0.75,
+    Λ=[0.25, 0.75, 0.75],
     opt=Flux.Optimise.Descent(0.05),
-    use_class_loss=false
 )
\ No newline at end of file
-- 
GitLab