From 9ee71bc1d7fb68cbcff7efb073d099e532909af0 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Fri, 1 Sep 2023 15:15:01 +0200
Subject: [PATCH] sorted out anonymous function issue

---
 experiments/experiment.jl      |  2 +-
 experiments/gmsc.jl            |  2 ++
 experiments/mnist.jl           |  1 +
 experiments/post_processing.jl | 24 +++++++++++++++++-------
 src/generator.jl               | 19 +++++++------------
 5 files changed, 28 insertions(+), 20 deletions(-)

diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 469a123e..5faf62f9 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -21,7 +21,7 @@ Base.@kwdef struct Experiment
     use_ensembling::Bool = true
     coverage::Float64 = DEFAULT_COVERAGE
     generators::Union{Nothing,Dict} = nothing
-    n_individuals::Int = 50
+    n_individuals::Int = 25
     ce_measures::AbstractArray = CE_MEASURES
     model_measures::Dict = MODEL_MEASURES
     use_class_loss::Bool = false
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index 132e02f3..c1ca3c2b 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -14,4 +14,6 @@ run_experiment(
     use_ensembling = true,
     Λ=[0.1, 0.5, 0.5],
     opt = Flux.Optimise.Descent(0.05),
+    n_individuals = 10,
+    use_variants = false, 
 )
\ No newline at end of file
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 557143e9..1b561b22 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -55,4 +55,5 @@ run_experiment(
     sampling_steps=25,
     use_ensembling = true,
     generators = generator_dict,
+    n_individuals = 5
 )
\ No newline at end of file
diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl
index 7e8f8e63..83672366 100644
--- a/experiments/post_processing.jl
+++ b/experiments/post_processing.jl
@@ -23,6 +23,7 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false)
     # Unpack:
     exp = outcome.exp
     n_obs, batch_size = meta_data(exp)
+    model_dict = outcome.model_dict
 
     params = DataFrame(
         Dict(
@@ -30,13 +31,13 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false)
             :batch_size => batch_size,
             :dataname => exp.dataname,
             :sgld_batch_size => exp.sampling_batch_size,
-            # :epochs => exp.epochs,
-            # :n_hidden => n_hidden,
-            # :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
-            # :activation => string(activation),
-            # :n_ens => n_ens,
-            # :lambda => string(α[3]),
-            # :jem_sampling_steps => jem.sampling_steps,
+            :epochs => exp.epochs,
+            :n_hidden => exp.n_hidden,
+            :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
+            :activation => string(activation),
+            :n_ens => exp.n_ens,
+            :lambda => exp.string(exp.α[3]),
+            :jem_sampling_steps => exp.sampling_steps,
         )
     )
    
@@ -55,6 +56,8 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false)
     # Unpack:
     exp = outcome.exp
     generator_dict = outcome.generator_dict
+    Λ = exp.Λ
+    Λ_Δ = exp.Λ_Δ
 
     # Output:
     opt = first(values(generator_dict)).opt
@@ -63,6 +66,13 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false)
             :opt => string(typeof(opt)),
             :eta => opt.eta,
             :dataname => dataname,
+            :lambda_1 => string(Λ[1]),
+            :lambda_2 => string(Λ[2]),
+            :lambda_3 => string(Λ[3]),
+            :lambda_1_Δ => string(Λ_Δ[1]),
+            :lambda_2_Δ => string(Λ_Δ[2]),
+            :lambda_3_Δ => string(Λ_Δ[3]),
+            :n_individuals => exp.n_individuals,
         )
     )
 
diff --git a/src/generator.jl b/src/generator.jl
index b4d986c5..1939d565 100644
--- a/src/generator.jl
+++ b/src/generator.jl
@@ -26,19 +26,14 @@ function ECCCoGenerator(;
         loss_fun = nothing
     end
 
-    # Set size penalty
-    _set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
+    _energy_penalty =
+        use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin))
 
-    # Energy penalty
-    _energy_penalty = function(ce::AbstractCounterfactualExplanation)
-        if use_energy_delta
-            return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin) 
-        else
-            return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin)
-        end
-    end
-
-    _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
+    _penalties = [
+        (Objectives.distance_l1, []), 
+        (ECCCo.set_size_penalty, (κ=κ, temp=temp)),
+        _energy_penalty,
+    ]
     λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ
 
     # Generator
-- 
GitLab