From 804c844276267674d9dbbd9064913f096fbf0137 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Thu, 31 Aug 2023 08:12:16 +0200
Subject: [PATCH] added parallelization option

---
 experiments/benchmarking/benchmarking.jl | 14 +++++++++++---
 experiments/experiment.jl                |  6 ++++++
 experiments/setup_env.jl                 |  1 +
 3 files changed, 18 insertions(+), 3 deletions(-)

diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl
index f4176e58..4a27f5e2 100644
--- a/experiments/benchmarking/benchmarking.jl
+++ b/experiments/benchmarking/benchmarking.jl
@@ -1,4 +1,4 @@
-function default_generators(
+function default_generators(;
     Λ::AbstractArray=[0.25, 0.75, 0.75],
     Λ_Δ::AbstractArray=Λ,
     use_variants::Bool=true,
@@ -49,10 +49,17 @@ function run_benchmark(exp::Experiment, model_dict::Dict)
     counterfactual_data = exp.counterfactual_data
     generator_dict = exp.generators
     measures = exp.ce_measures
+    parallelizer = exp.parallelizer
 
     # Benchmark generators:
     if isnothing(generator_dict)
-        generator_dict = default_generators()
+        generator_dict = default_generators(;
+            Λ=exp.Λ,
+            Λ_Δ=exp.Λ_Δ,
+            use_variants=exp.use_variants,
+            use_class_loss=exp.use_class_loss,
+            opt=exp.opt
+        )
     end
 
     # Run benchmark:
@@ -72,7 +79,8 @@ function run_benchmark(exp::Experiment, model_dict::Dict)
                 n_individuals=n_individuals,
                 target=target, factual=factual,
                 initialization=:identity,
-                converge_when=:generator_conditions
+                converge_when=:generator_conditions,
+                parallelizer=parallelizer,
             )
             push!(bmks, bmk)
         end
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 9c6f7f3f..f3770aa4 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -17,6 +17,12 @@ Base.@kwdef struct Experiment
     n_individuals::Int = 50
     ce_measures::AbstractArray = CE_MEASURES
     model_measures::Dict = MODEL_MEASURES
+    use_class_loss::Bool = true
+    use_variants::Bool = true
+    Λ::AbstractArray = [0.25, 0.75, 0.75]
+    Λ_Δ::AbstractArray = Λ
+    opt::Flux.Optimise.Optimizer = Flux.Optimise.Descent(0.01)
+    parallelizer::Union{Nothing, AbstractParallelizer} = nothing
 end
 
 "A container to hold the results of an experiment."
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 1ee2e24c..564bc50a 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -11,6 +11,7 @@ using CounterfactualExplanations.Evaluation: benchmark, evaluate, Benchmark
 using CounterfactualExplanations.Generators: JSMADescent
 using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs
 using CounterfactualExplanations.Objectives
+using CounterfactualExplanations.Parallelization
 using CSV
 using DataFrames
 using Distributions: Normal, Distribution, Categorical
-- 
GitLab