From 7a2a2853738ff84df374a2f08136f54f5a59f0fd Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 22 Aug 2023 14:32:51 +0200
Subject: [PATCH] slowly slowly

---
 experiments/benchmarking/benchmarking.jl      |  26 +--
 experiments/circles.jl                        |   2 +-
 experiments/data/data.jl                      |  33 ++-
 experiments/gmsc.jl                           |   2 +-
 experiments/linearly_separable.jl             |   2 +-
 experiments/mnist.jl                          |   2 +-
 experiments/models/models.jl                  |  36 +++-
 experiments/moons.jl                          |   2 +-
 experiments/post_processing.jl                |  73 +++++++
 experiments/setup.jl                          | 198 +++++-------------
 .../linearly_separable_model_performance.csv  |   3 +
 .../linearly_separable_model_performance.jls  | Bin 0 -> 325 bytes
 .../linearly_separable_generator_params.csv   |   2 +
 .../linearly_separable_model_params.csv       |   2 +
 14 files changed, 211 insertions(+), 172 deletions(-)
 create mode 100644 experiments/post_processing.jl
 create mode 100644 replicated/linearly_separable_model_performance.csv
 create mode 100644 replicated/linearly_separable_model_performance.jls
 create mode 100644 replicated/params/linearly_separable_generator_params.csv
 create mode 100644 replicated/params/linearly_separable_model_params.csv

diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl
index 164e54c3..40d9457c 100644
--- a/experiments/benchmarking/benchmarking.jl
+++ b/experiments/benchmarking/benchmarking.jl
@@ -1,13 +1,3 @@
-"The default benchmarking measures."
-const default_measures = [
-    CounterfactualExplanations.distance,
-    ECCCo.distance_from_energy,
-    ECCCo.distance_from_targets,
-    CounterfactualExplanations.Evaluation.validity,
-    CounterfactualExplanations.Evaluation.redundancy,
-    ECCCo.set_size_penalty
-]
-
 function default_generators(
     Λ::AbstractArray=[0.25, 0.75, 0.75],
     Λ_Δ::AbstractArray=[Λ[1], Λ[2], 4.0],
@@ -52,14 +42,14 @@ end
 
 Run the benchmarking procedure.
 """
-function run_benchmark(;
-    n_individuals::Int,
-    dataname::String,
-    counterfactual_data::CounterfactualData,
-    model_dict::Dict,
-    generators::Union{Nothing, Dict}=nothing,
-    measures::AbstractArray=default_measures,   
-)
+function run_benchmark(exp::Experiment, model_dict::Dict)
+
+    n_individuals = exp.n_individuals
+    dataname = exp.dataname
+    counterfactual_data = exp.counterfactual_data
+    generators = exp.generators
+    measures = exp.ce_measures
+
     # Benchmark generators:
     if isnothing(generators)
         generator_dict = default_generators()
diff --git a/experiments/circles.jl b/experiments/circles.jl
index c8f9e6fd..75a9b9ca 100644
--- a/experiments/circles.jl
+++ b/experiments/circles.jl
@@ -1,6 +1,6 @@
 n_obs = Int(1000 / (1.0 - test_size))
 counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=test_size)
-run_experiment(
+run_experiment!(
     counterfactual_data, test_data; dataname="Circles",
     n_hidden=32,
     α=[1.0, 1.0, 1e-2],
diff --git a/experiments/data/data.jl b/experiments/data/data.jl
index 47101aab..0facb7fa 100644
--- a/experiments/data/data.jl
+++ b/experiments/data/data.jl
@@ -1,15 +1,17 @@
-function prepare_data(
-    counterfactual_data::CounterfactualData;
-    𝒟x=Normal(),
-    min_batch_size=128,
-    sampling_batch_size=50,
-)
+function _prepare_data(exp::Experiment)
+
+    # Unpack data:
+    counterfactual_data = exp.counterfactual_data
+    min_batch_size = exp.min_batch_size
+    sampling_batch_size = exp.sampling_batch_size
+    𝒟x = exp.𝒟x
+
+    # Data parameters:
     X, _ = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
     X = table(permutedims(X))
     labels = counterfactual_data.output_encoder.labels
     input_dim, n_obs = size(counterfactual_data.X)
     output_dim = length(unique(labels))
-    save_name = replace(lowercase(dataname), " " => "_")
 
     # Model parameters:
     batch_size = minimum([Int(round(n_obs / 10)), min_batch_size])
@@ -22,5 +24,20 @@ function prepare_data(
         input_size=(input_dim,),
         batch_size=sampling_batch_size,
     )
-    return X, labels, n_obs, save_name, batch_size, sampler
+    return X, labels, n_obs, batch_size, sampler
+end
+
+function meta_data(exp::Experiment)
+    _, _, n_obs, batch_size, _ = _prepare_data(exp::Experiment)
+    return n_obs, batch_size
+end
+
+function prepare_data(exp::Experiment)
+    X, labels, _, _,  sampler = _prepare_data(exp::Experiment)
+    return X, labels, sampler
+end
+
+function batch_size(exp::Experiment)
+    _, _, _, batch_size, _ = _prepare_data(exp::Experiment)
+    return batch_size
 end
\ No newline at end of file
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index 58d6b85b..f162f256 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -1,5 +1,5 @@
 counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size)
-run_experiment(
+run_experiment!(
     counterfactual_data, test_data; dataname="GMSC",
     n_hidden=128,
     activation = Flux.swish,
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index 29a554b2..259731b2 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -3,4 +3,4 @@ counterfactual_data, test_data = train_test_split(
     load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
     test_size=test_size
 )
-run_experiment(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
+run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 39e13d40..88280bea 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -41,7 +41,7 @@ generator_dict = Dict(
 )
 
 # Run:
-run_experiment(
+run_experiment!(
     counterfactual_data, test_data; dataname="MNIST",
     n_hidden = 128,
     activation = Flux.swish,
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 7adbaa0e..b3e3286d 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -1,3 +1,37 @@
 include("additional_models.jl")
 include("default_models.jl")
-include("train_models.jl")
\ No newline at end of file
+include("train_models.jl")
+
+function prepare_models(exp::Experiment)
+
+    # Unpack data:
+    X, labels, sampler = prepare_data(exp::Experiment)
+
+    # Setup:
+    if isnothing(exp.builder)
+        builder = default_builder()
+    end
+    if isnothing(exp.models)
+        @info "Using default models."
+        models = default_models(;
+            sampler=sampler,
+            builder=builder,
+            batch_size=batch_size(exp)
+        )
+    end
+
+    # Training:
+    if !pretrained
+        @info "Training models."
+        model_dict = train_models(models, X, labels; coverage=exp.coverage)
+    else
+        @info "Loading pre-trained models."
+        model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(exp.save_name)_models.jls"))
+    end
+
+    # Save models:
+    @info "Saving models to $(joinpath(exp.output_path, "$(exp.save_name)_models.jls"))."
+    Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_models.jls"), model_dict)
+
+    return model_dict
+end
\ No newline at end of file
diff --git a/experiments/moons.jl b/experiments/moons.jl
index 01e124b1..ad492eb5 100644
--- a/experiments/moons.jl
+++ b/experiments/moons.jl
@@ -1,6 +1,6 @@
 n_obs = Int(2500 / (1.0 - test_size))
 counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=test_size)
-run_experiment(
+run_experiment!(
     counterfactual_data, test_data; dataname="Moons",
     epochs=500,
     n_hidden=32,
diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl
new file mode 100644
index 00000000..36c4b60f
--- /dev/null
+++ b/experiments/post_processing.jl
@@ -0,0 +1,73 @@
+"""
+    meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES)
+
+Compute and save the model performance for the models in `outcome.model_dict`.
+"""
+function meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES)
+
+    exp = outcome.exp
+    model_dict = outcome.model_dict
+
+    # Model performance:
+    model_performance = DataFrame()
+    for (mod_name, model) in model_dict
+        # Test performance:
+        _perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measure)))
+        _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
+        _perf.mod_name .= mod_name
+        _perf.dataname .= dataname
+        model_performance = vcat(model_performance, _perf)
+    end
+    Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_model_performance.jls"), model_performance)
+    CSV.write(joinpath(exp.output_path, "$(exp.save_name)_model_performance.csv"), model_performance)
+    @info "Model performance:"
+    println(model_performance)
+    return model_performance
+end
+
+"""
+    meta_data(exp::Experiment)
+
+Extract and save meta data about the experiment.
+"""
+function meta_data(outcome::ExperimentOutcome)
+
+    # Data params:
+    _, _, n_obs, default_save_name, batch_size, sampler = prepare_data(
+        exp.counterfactual_data;
+        𝒟x=exp.𝒟x,
+        sampling_batch_size=exp.sampling_batch_size
+    )
+    save_name = isnothing(save_name) ? default_save_name : save_name
+
+    params = DataFrame(
+        Dict(
+            :n_obs => Int.(round(n_obs / 10) * 10),
+            :epochs => epochs,
+            :batch_size => batch_size,
+            :n_hidden => n_hidden,
+            :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
+            :activation => string(activation),
+            :n_ens => n_ens,
+            :lambda => string(α[3]),
+            :jem_sampling_steps => jem.sampling_steps,
+            :sgld_batch_size => sampler.batch_size,
+            :dataname => dataname,
+        )
+    )
+    if !isnothing(save_path)
+        CSV.write(joinpath(save_path, "$(save_name)_model_params.csv"), params)
+    end
+
+    # Output:
+    opt = first(values(generator_dict)).opt
+    generator_params = DataFrame(
+        Dict(
+            :opt => string(typeof(opt)),
+            :eta => opt.eta,
+            :dataname => dataname,
+        )
+    )
+    CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params)
+
+end
\ No newline at end of file
diff --git a/experiments/setup.jl b/experiments/setup.jl
index e50984a5..93834ab8 100644
--- a/experiments/setup.jl
+++ b/experiments/setup.jl
@@ -8,187 +8,105 @@ test_size = 0.2
 const DEFAULT_OUTPUT_PATH = "$(pwd())/results"
 const RETRAIN = "retrain" ∈ ARGS ? true : false
 
+"Default model performance measures."
+const MODEL_MEASURES = Dict(
+    :f1score => multiclass_f1score,
+    :acc => accuracy,
+    :precision => multiclass_precision
+)
+
+"Default coverage rate."
+const DEFAULT_COVERAGE = 0.95
+
+"The default benchmarking measures."
+const CE_MEASURES = [
+    CounterfactualExplanations.distance,
+    ECCCo.distance_from_energy,
+    ECCCo.distance_from_targets,
+    CounterfactualExplanations.Evaluation.validity,
+    CounterfactualExplanations.Evaluation.redundancy,
+    ECCCo.set_size_penalty
+]
+
 # Pre-trained models:
 function pretrained_path()
     @info "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
     return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5")
 end
 
-# Scripts:
-include("data/data.jl")
-include("models/models.jl")
-include("benchmarking/benchmarking.jl")
-
 "Sets up the experiment."
 Base.@kwdef struct Experiment
     counterfactual_data::CounterfactualData
     test_data::CounterfactualData
     dataname::String = "dataset"
+    save_name::String = replace(lowercase(dataname), " " => "_")
     output_path::String = DEFAULT_OUTPUT_PATH
     params_path::String = joinpath(output_path, "params")
     use_pretrained::Bool = true
-    models::Union{Nothing, Dict} = nothing
-    builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing
+    models::Union{Nothing,Dict} = nothing
+    builder::Union{Nothing,MLJFlux.GenericBuilder} = nothing
     𝒟x::Distribution = Normal()
     sampling_batch_size::Int = 50
-    coverage::Float64 = 0.95
-    generators::Union{Nothing, Dict} = nothing
+    min_batch_size::Int = 128
+    coverage::Float64 = DEFAULT_COVERAGE
+    generators::Union{Nothing,Dict} = nothing
     n_individuals::Int = 50
+    ce_measures::AbstractArray = CE_MEASURES
+end
+
+"A container to hold the results of an experiment."
+mutable struct ExperimentOutcome
+    exp::Experiment
+    model_dict::Union{Nothing, Dict}
+    generator_dict::Union{Nothing, Dict}
+    bmk::Union{Nothing, Benchmark}
 end
 
+# Scripts:
+include("data/data.jl")
+include("models/models.jl")
+include("benchmarking/benchmarking.jl")
+include("post_processing.jl")
+
 """
-    run_experiment(exp::Experiment)
+    run_experiment!(exp::Experiment)
 
 Run the experiment specified by `exp`.
 """
-function run_experiment(exp::Experiment)
+function run_experiment!(exp::Experiment)
     
-    # SETUP ----------
+    # Setup
     @info "All results will be saved to $(exp.output_path)."
     isdir(exp.output_path) || mkdir(exp.output_path)
     @info "All parameter choices will be saved to $(exp.params_path)."
     isdir(exp.params_path) || mkdir(exp.params_path)
 
-    # Data
-    X, labels, n_obs, save_name, batch_size, sampler = prepare_data(
-        counterfactual_data;
-        𝒟x=exp.𝒟x,
-        sampling_batch_size=exp.sampling_batch_size,
-    )
-
-    # MODELS ----------
-    if isnothing(builder)
-        builder = default_builder()
-    end
-    if isnothing(models)
-        @info "Using default models."
-        models = default_models(;
-            sampler=sampler,
-            builder=builder,
-            batch_size=batch_size,
-        )
-    end
-
-    # TRAINING ----------
-    if !pretrained
-        @info "Training models."
-        model_dict = train_models(models, X, labels; coverage=coverage)
-        Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict)
-    else
-        @info "Loading pre-trained models."
-        model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls"))
-    end
-
-    params = DataFrame(
-        Dict(
-            :n_obs => Int.(round(n_obs/10)*10),
-            :epochs => epochs,
-            :batch_size => batch_size,
-            :n_hidden => n_hidden,
-            :n_layers => length(model_dict["MLP"].fitresult[1][1])-1,
-            :activation => string(activation),
-            :n_ens => n_ens,
-            :lambda => string(α[3]),
-            :jem_sampling_steps => jem.sampling_steps,
-            :sgld_batch_size => sampler.batch_size,
-            :dataname => dataname,
-        )
-    )
-    CSV.write(joinpath(params_path, "$(save_name)_model_params.csv"), params)
-
-    measure = Dict(
-        :f1score => multiclass_f1score,
-        :acc => accuracy,
-        :precision => multiclass_precision
-    )
-    model_performance = DataFrame()
-    for (mod_name, model) in model_dict
-        # Test performance:
-        _perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
-        _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
-        _perf.mod_name .= mod_name
-        _perf.dataname .= dataname
-        model_performance = vcat(model_performance, _perf)
-    end
-    Serialization.serialize(joinpath(output_path, "$(save_name)_model_performance.jls"), model_performance)
-    CSV.write(joinpath(output_path, "$(save_name)_model_performance.csv"), model_performance)
-    @info "Model performance:"
-    println(model_performance)
+    # Models
+    model_dict = prepare_models(exp)
+    outcome = ExperimentOutcome(exp, model_dict, nothing, nothing)
+    meta_model_performance(outcome)
     
-    # COUNTERFACTUALS ----------
-    # Benchmark generators:
-    bmk, generator_dict = run_benchmark(;
-        n_individuals=n_individuals,
-        dataname=dataname,
-        counterfactual_data=counterfactual_data,
-        model_dict=model_dict,
-        generators=generators,
-        measures=measures,
-    )
+    # Benchmark
+    bmk, generator_dict = run_benchmark(exp, model_dict)
+    outcome.bmk = bmk
+    outcome.generator_dict = generator_dict
 
-    # Output:
-    opt = first(values(generator_dict)).opt
-    generator_params = DataFrame(
-        Dict(
-            :opt => string(typeof(opt)),
-            :eta => opt.eta,
-            :dataname => dataname,
-        )
-    )
-    CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params)
-    CSV.write(joinpath(output_path, "$(save_name)_benchmark.csv"), bmk())
+    # Save data:
+    Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_outcome.jls"), outcome)
+    CSV.write(joinpath(exp.output_path, "$(exp.ave_name)_benchmark.csv"), bmk())
 
 end
 
 """
-    run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
+    run_experiment!(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
 
 Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments.
 """
-function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
+function run_experiment!(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
     # Parameters:
     exp = Experiment(
         counterfactual_data, test_data;
         kwargs...
     )
-    return run_experiment(exp)
-end
-
-"""
-    meta_data(exp::Experiment)
-
-Extract and save meta data about the experiment.
-"""
-function meta_data(
-    exp::Experiment; 
-    save_path::Union{String,Nothing}=nothing,
-    save_name::Union{String,Nothing}=nothing,
-)
-
-    # Data params:
-    _, _, n_obs, default_save_name, batch_size, sampler = prepare_data(
-        exp.counterfactual_data;
-        𝒟x=exp.𝒟x,
-        sampling_batch_size=exp.sampling_batch_size
-    )
-    save_name = isnothing(save_name) ? default_save_name : save_name
-
-    params = DataFrame(
-        Dict(
-            :n_obs => Int.(round(n_obs / 10) * 10),
-            :epochs => epochs,
-            :batch_size => batch_size,
-            :n_hidden => n_hidden,
-            :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
-            :activation => string(activation),
-            :n_ens => n_ens,
-            :lambda => string(α[3]),
-            :jem_sampling_steps => jem.sampling_steps,
-            :sgld_batch_size => sampler.batch_size,
-            :dataname => dataname,
-        )
-    )
-    if !isnothing(save_path)
-        CSV.write(joinpath(save_path, "$(save_name)_model_params.csv"), params)
-    end
+    return run_experiment!(exp)
 end
\ No newline at end of file
diff --git a/replicated/linearly_separable_model_performance.csv b/replicated/linearly_separable_model_performance.csv
new file mode 100644
index 00000000..b553ee64
--- /dev/null
+++ b/replicated/linearly_separable_model_performance.csv
@@ -0,0 +1,3 @@
+acc,precision,f1score,mod_name,dataname
+0.992,0.992,0.992,MLP,Linearly Separable
+0.992,0.9921259842519685,0.9919994879672299,JEM,Linearly Separable
diff --git a/replicated/linearly_separable_model_performance.jls b/replicated/linearly_separable_model_performance.jls
new file mode 100644
index 0000000000000000000000000000000000000000..2da7edaff6f46dceb83f9dfeb037d1a89ca2faa2
GIT binary patch
literal 325
zcmXr_@)2ZVU|=v6VB~a3EJ<`LO3Y1_=RV+OXD@kU+xC{Q+NTWpHY*vq;F85IqCyNZ
zN$zd}jJ%FX#U(|F$t8|OMTwR2ezRSefJW)CGCX)9%E0$9!g9-rvVHIEVKlPXhws1t
z*6n|9k1Tfm=eez=K=q;wiVqc;eSHEHnY~<nK|D)EA)m~=)Wo8kN`>Img2bZ4q?}Y8
zUIr5ZMpn<fl++4L7wa%HxBxxH;*yzMBJVd3>@o&M1`bxB3mKUclan7XauyV&CTA9B
z=I1?R<jBoWiO&Q2{V^j)3NRc%%qNWOX@<qg`9-OZL>U;LfOY8zGQdT3gcx*$8T|Zw
E0a8nCX8-^I

literal 0
HcmV?d00001

diff --git a/replicated/params/linearly_separable_generator_params.csv b/replicated/params/linearly_separable_generator_params.csv
new file mode 100644
index 00000000..ab1f0342
--- /dev/null
+++ b/replicated/params/linearly_separable_generator_params.csv
@@ -0,0 +1,2 @@
+dataname,eta,opt,λ1,λ2,λ3
+Linearly Separable,0.01,Descent,0.25,0.75,0.75
diff --git a/replicated/params/linearly_separable_model_params.csv b/replicated/params/linearly_separable_model_params.csv
new file mode 100644
index 00000000..0a58fca1
--- /dev/null
+++ b/replicated/params/linearly_separable_model_params.csv
@@ -0,0 +1,2 @@
+activation,batch_size,dataname,epochs,jem_sampling_steps,lambda,n_ens,n_hidden,n_layers,n_obs,sgld_batch_size
+swish,100,Linearly Separable,100,30,0.1,5,16,3,1000,50
-- 
GitLab