From 9fafe55b33d7c603bfa68c60497a891d3dcc05e7 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 22 Aug 2023 17:29:18 +0200
Subject: [PATCH] Fix experiment kwargs, builder type and pretrained-model loading
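
Fix several issues in the experiments pipeline:

- pass the `test_size` keyword (rather than the constant name
  `TEST_SIZE`) to `train_test_split` in the circles, GMSC, linearly
  separable and moons scripts;
- type the `builder` field and keyword as `MLJFlux.Builder` instead of
  the non-existent `MLJFlux.GenericBuilder`;
- call the `Experiment` keyword constructor explicitly so that trailing
  `kwargs...` splat cleanly;
- fix the undefined `pretrained` guard (now `exp.use_pretrained`), only
  assemble default builders and models when training from scratch, and
  load pre-trained models from the `results/` subfolder of the
  artifact;
- use `exp.dataname` in `meta_model_performance` instead of an
  undefined `dataname`;
- declare the DataFrames, Flux and Serialization dependencies and
  narrow the Distributions import.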

---
 experiments/Manifest.toml            |  2 +-
 experiments/Project.toml             |  3 +++
 experiments/circles.jl               |  2 +-
 experiments/experiment.jl            |  7 ++++---
 experiments/gmsc.jl                  |  2 +-
 experiments/linearly_separable.jl    |  2 +-
 experiments/models/default_models.jl |  2 +-
 experiments/models/models.jl         | 28 +++++++++++++---------------
 experiments/moons.jl                 |  2 +-
 experiments/post_processing.jl       |  2 +-
 experiments/setup_env.jl             |  9 ++++++---
 11 files changed, 33 insertions(+), 28 deletions(-)
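
Notes: a minimal usage sketch of the corrected `train_test_split`
keyword, assuming the `load_circles` loader and the `TEST_SIZE`
constant used elsewhere in this repository:

    TEST_SIZE = 0.2
    n_obs = Int(1000 / (1.0 - TEST_SIZE))
    # The keyword is lowercase `test_size`; the constant `TEST_SIZE`
    # is only its value, not the keyword name.
    counterfactual_data, test_data = train_test_split(
        load_circles(n_obs; noise=0.05, factor=0.5);
        test_size=TEST_SIZE,
    )

Similarly, `Experiment` is a `Base.@kwdef` struct, so it is now built
with explicit keywords and any extra `kwargs...` splat into the
generated keyword constructor; a sketch continuing from the snippet
above:

    exp = Experiment(;
        counterfactual_data=counterfactual_data,
        test_data=test_data,
        use_pretrained=false,  # field shown in the diff below
    )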

diff --git a/experiments/Manifest.toml b/experiments/Manifest.toml
index 70510d5a..934a3d38 100644
--- a/experiments/Manifest.toml
+++ b/experiments/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.9.2"
 manifest_format = "2.0"
-project_hash = "a981b871c7220e241da2ed077ce7f07968e32677"
+project_hash = "59bc4ed7aabf01af8804eba2ac9b602604180c6c"
 
 [[deps.AbstractFFTs]]
 deps = ["LinearAlgebra"]
diff --git a/experiments/Project.toml b/experiments/Project.toml
index 0c27e4aa..c99239a4 100644
--- a/experiments/Project.toml
+++ b/experiments/Project.toml
@@ -1,11 +1,14 @@
 [deps]
 CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
 CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
+Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
 LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
 MLJFlux = "094fc8d1-fd35-5302-93ea-dabda2abf845"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
+Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
diff --git a/experiments/circles.jl b/experiments/circles.jl
index 0872ba15..f67c83a0 100644
--- a/experiments/circles.jl
+++ b/experiments/circles.jl
@@ -1,5 +1,5 @@
 n_obs = Int(1000 / (1.0 - TEST_SIZE))
-counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); TEST_SIZE=TEST_SIZE)
+counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE)
 run_experiment!(
     counterfactual_data, test_data; dataname="Circles",
     n_hidden=32,
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 3394be91..41559715 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -8,7 +8,7 @@ Base.@kwdef struct Experiment
     params_path::String = joinpath(output_path, "params")
     use_pretrained::Bool = true
     models::Union{Nothing,Dict} = nothing
-    builder::Union{Nothing,MLJFlux.GenericBuilder} = nothing
+    builder::Union{Nothing,MLJFlux.Builder} = nothing
     𝒟x::Distribution = Normal()
     sampling_batch_size::Int = 50
     min_batch_size::Int = 128
@@ -69,8 +69,9 @@ Overload the `run_experiment` function to allow for passing in `CounterfactualDa
 """
 function run_experiment!(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...)
     # Parameters:
-    exp = Experiment(
-        counterfactual_data, test_data;
+    exp = Experiment(;
+        counterfactual_data=counterfactual_data,
+        test_data=test_data,
         kwargs...
     )
     return run_experiment!(exp)
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index baf5e580..62575fd5 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -1,4 +1,4 @@
-counterfactual_data, test_data = train_test_split(load_gmsc(nothing); TEST_SIZE=TEST_SIZE)
+counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
 run_experiment!(
     counterfactual_data, test_data; dataname="GMSC",
     n_hidden=128,
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index c95a0509..f884bd20 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -1,6 +1,6 @@
 n_obs = Int(1000 / (1.0 - TEST_SIZE))
 counterfactual_data, test_data = train_test_split(
     load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
-    TEST_SIZE=TEST_SIZE
+    test_size=TEST_SIZE
 )
 run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
diff --git a/experiments/models/default_models.jl b/experiments/models/default_models.jl
index 4b87ebce..e57e926f 100644
--- a/experiments/models/default_models.jl
+++ b/experiments/models/default_models.jl
@@ -30,7 +30,7 @@ Builds a dictionary of default models for training.
 """
 function default_models(;
     sampler::AbstractSampler,
-    builder::MLJFlux.GenericBuilder=default_builder(),
+    builder::MLJFlux.Builder=default_builder(),
     epochs::Int=25,
     batch_size::Int=128,
     finaliser::Function=Flux.softmax,
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index b3e3286d..f9d645bf 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -7,26 +7,24 @@ function prepare_models(exp::Experiment)
     # Unpack data:
     X, labels, sampler = prepare_data(exp::Experiment)
 
-    # Setup:
-    if isnothing(exp.builder)
-        builder = default_builder()
-    end
-    if isnothing(exp.models)
-        @info "Using default models."
-        models = default_models(;
-            sampler=sampler,
-            builder=builder,
-            batch_size=batch_size(exp)
-        )
-    end
-
     # Training:
-    if !pretrained
+    if !exp.use_pretrained
+        # Fall back to defaults only where no builder/models were supplied:
+        builder = isnothing(exp.builder) ? default_builder() : exp.builder
+        models = exp.models
+        if isnothing(models)
+            @info "Using default models."
+            models = default_models(;
+                sampler=sampler,
+                builder=builder,
+                batch_size=batch_size(exp)
+            )
+        end
         @info "Training models."
         model_dict = train_models(models, X, labels; coverage=exp.coverage)
     else
         @info "Loading pre-trained models."
-        model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(exp.save_name)_models.jls"))
+        model_dict = Serialization.deserialize(joinpath(pretrained_path(), "results/$(exp.save_name)_models.jls"))
     end
 
     # Save models:
diff --git a/experiments/moons.jl b/experiments/moons.jl
index 198a605e..76aa72c5 100644
--- a/experiments/moons.jl
+++ b/experiments/moons.jl
@@ -1,5 +1,5 @@
 n_obs = Int(2500 / (1.0 - TEST_SIZE))
-counterfactual_data, test_data = train_test_split(load_moons(n_obs); TEST_SIZE=TEST_SIZE)
+counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=TEST_SIZE)
 run_experiment!(
     counterfactual_data, test_data; dataname="Moons",
     epochs=500,
diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl
index 58d17b82..144b84b3 100644
--- a/experiments/post_processing.jl
+++ b/experiments/post_processing.jl
@@ -88,7 +88,7 @@ function meta_model_performance(outcome::ExperimentOutcome; measures::Union{Noth
         _perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measures)))
         _perf = DataFrame([[p] for p in _perf], collect(keys(measures)))
         _perf.mod_name .= mod_name
-        _perf.dataname .= dataname
+        _perf.dataname .= exp.dataname
         model_performance = vcat(model_performance, _perf)
     end
     
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 1fd73899..f2f84e44 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -10,20 +10,23 @@ using CounterfactualExplanations.Generators: JSMADescent
 using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs
 using CounterfactualExplanations.Objectives
 using CSV
-using Distributions
+using DataFrames
+using Distributions: Normal, Distribution, Categorical
 using ECCCo
+using Flux
 using JointEnergyModels
 using LazyArtifacts
-using MLJBase: multiclass_f1score, accuracy, multiclass_precision
+using MLJBase: multiclass_f1score, accuracy, multiclass_precision, table
 using MLJEnsembles
 using MLJFlux
+using Serialization
 
 # Constants:
 const LATEST_VERSION = "1.8.5"
 const ARTIFACT_NAME = "results-paper-submission-$(LATEST_VERSION)"
 artifact_toml = LazyArtifacts.find_artifacts_toml(".")
 _hash = artifact_hash(ARTIFACT_NAME, artifact_toml)
-const LATEST_ARTIFACT_PATH = artifact_path(_hash)
+const LATEST_ARTIFACT_PATH = joinpath(artifact_path(_hash), ARTIFACT_NAME)
 
 # Pre-trained models:
 function pretrained_path()
-- 
GitLab