From ed4003e0170a055c00dd03faa339df9647742df2 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Thu, 21 Sep 2023 16:56:34 +0200
Subject: [PATCH] omg

---
 experiments/benchmarking/benchmarking.jl | 2 +-
 experiments/mnist.jl                     | 8 ++++----
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl
index 05bf7977..6e5ca946 100644
--- a/experiments/benchmarking/benchmarking.jl
+++ b/experiments/benchmarking/benchmarking.jl
@@ -18,7 +18,7 @@ function default_generators(;
         generator_dict = Dict(
             "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
             "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
-            "Schut" => GreedyGenerator(),
+            "Schut" => GreedyGenerator(η=opt.eta),
             "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
             "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
             "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 4ceef766..9f7ab1bd 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -31,7 +31,7 @@ ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_f
 
 # Parameter choices:
 params = (
-    n_individuals=N_IND_SPECIFIED ? N_IND : 5,
+    n_individuals=N_IND_SPECIFIED ? N_IND : 10,
     builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish),
     𝒟x=Uniform(-1.0, 1.0),
     α=[1.0, 1.0, 1e-2],
@@ -44,9 +44,9 @@ params = (
     nsamples=10,
     nmin=1,
     niter_eccco=10,
-    Λ=[0.1, 0.25, 0.25],
-    Λ_Δ=[0.1, 0.1, 1.0],
-    opt=Flux.Optimise.Descent(0.25),
+    Λ=[0.01, 0.25, 0.25],
+    Λ_Δ=[0.01, 0.1, 1.0],
+    opt=Flux.Optimise.Descent(0.1),
     reg_strength = 0.01,
     ce_measures=ce_measures,
 )
-- 
GitLab