From d09efd11fa2d0c51cc20fe0eb854127d83ea7a23 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Wed, 13 Sep 2023 10:41:54 +0200
Subject: [PATCH] matame

---
 experiments/benchmarking/benchmarking.jl | 17 +++++++++--------
 experiments/experiment.jl                |  1 +
 experiments/gmsc.jl                      |  5 +++--
 src/generator.jl                         |  7 +++++--
 src/penalties.jl                         |  3 +--
 5 files changed, 19 insertions(+), 14 deletions(-)

diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl
index 64f864b3..ddfacbce 100644
--- a/experiments/benchmarking/benchmarking.jl
+++ b/experiments/benchmarking/benchmarking.jl
@@ -4,6 +4,7 @@ function default_generators(;
     use_variants::Bool=true,
     use_class_loss::Bool=false,
     opt=Flux.Optimise.Descent(0.01),
+    niter_eccco::Union{Nothing,Int}=nothing,
     nsamples::Union{Nothing,Int}=nothing,
     nmin::Union{Nothing,Int}=nothing,
     reg_strength::Real=0.5,
@@ -18,20 +19,20 @@ function default_generators(;
             "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
             "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
             "Schut" => GreedyGenerator(),
-            "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
-            "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
-            "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
-            "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
-            "ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λ₁_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
-            "ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λ₁_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
+            "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
+            "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
+            "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
+            "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength = reg_strength),
+            "ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λ₁_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength),
+            "ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λ₁_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength),
         )
     else
         generator_dict = Dict(
             "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
             "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
             "Schut" => GreedyGenerator(),
-            "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
-            "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
+            "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
+            "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength),
         )
     end
     return generator_dict
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 46a81a1d..b161df03 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -37,6 +37,7 @@ Base.@kwdef struct Experiment
     loss::Function = Flux.Losses.crossentropy
     train_parallel::Bool = false
     reg_strength::Real = 0.1
+    niter_eccco::Union{Nothing,Int} = nothing
 end
 
 "A container to hold the results of an experiment."
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index 31e6ced5..74ff1bb3 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -27,6 +27,7 @@ run_experiment(
     n_individuals = n_ind,
     use_variants = false, 
     min_batch_size = 250,
-    nsamples = 10,
-    nmin = 10,
+    nsamples = 100,
+    nmin = 1,
+    niter_eccco = 30,
 )
\ No newline at end of file
diff --git a/src/generator.jl b/src/generator.jl
index 5a1117b8..b52fa275 100644
--- a/src/generator.jl
+++ b/src/generator.jl
@@ -11,12 +11,15 @@ function ECCCoGenerator(;
     use_energy_delta::Bool=false,
     nsamples::Union{Nothing,Int}=nothing,
     nmin::Union{Nothing,Int}=nothing,
-    reg_strength::Real=0.5,
+    niter::Union{Nothing,Int}=nothing,
+    reg_strength::Real=0.1,
     kwargs...
 )
 
+    # Default ECCCo parameters
     nsamples = isnothing(nsamples) ? 50 : nsamples
     nmin = isnothing(nmin) ? 25 : nmin
+    niter = isnothing(niter) ? 500 : niter
 
     # Default optimiser
     if isnothing(opt)
@@ -31,7 +34,7 @@ function ECCCoGenerator(;
     end
 
     _energy_penalty =
-        use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin))
+        use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, niter=niter, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin, niter=niter))
 
     _penalties = [
         (Objectives.distance_l1, []), 
diff --git a/src/penalties.jl b/src/penalties.jl
index 7c57a067..9df579b5 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -78,8 +78,7 @@ function energy_delta(
             _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
         end
         eng_sampler = _dict[:energy_sampler]
-        generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter)
-        xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end]
+        xsampled = rand(eng_sampler, ce.num_counterfactuals; from_buffer=from_buffer)
         push!(conditional_samples, xsampled)
     end
 
-- 
GitLab