From aab1c928756c0c4f2e6f89f94f14ed1656da16aa Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Sat, 23 Sep 2023 18:25:52 +0200
Subject: [PATCH] uhlala

---
 experiments/california_housing.jl | 15 ++++++++++++++-
 experiments/experiment.jl         |  3 ++-
 experiments/german_credit.jl      | 15 ++++++++++++++-
 experiments/gmsc.jl               | 15 ++++++++++++++-
 4 files changed, 44 insertions(+), 4 deletions(-)

diff --git a/experiments/california_housing.jl b/experiments/california_housing.jl
index fa965f21..fe6de6e6 100644
--- a/experiments/california_housing.jl
+++ b/experiments/california_housing.jl
@@ -2,6 +2,18 @@
 dataname = "California Housing"
 counterfactual_data, test_data = train_test_split(load_california_housing(nothing); test_size=TEST_SIZE)
 
+# VAE:
+using CounterfactualExplanations.GenerativeModels: VAE, train!
+X = counterfactual_data.X
+y = counterfactual_data.output_encoder.y
+vae = VAE(size(X, 1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5)
+train!(vae, X, y)
+counterfactual_data.generative_model = vae
+
+# Dimensionality reduction:
+maxout_dim = vae.params.latent_dim
+counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -17,7 +29,8 @@ params = (
     sampling_batch_size=10,
     sampling_steps=30,
     use_ensembling=true,
-    opt=Flux.Optimise.Descent(0.05)
+    opt=Flux.Optimise.Descent(0.05),
+    dim_reduction=true,
 )
 
 # Best grid search params:
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 4f84ce3a..c7f00302 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -94,7 +94,8 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models::
 
     if FROM_GRID_SEARCH
         # Just load the best model from the grid search:
-        outcome = Serialization.deserialize(joinpath(exper.output_path, "grid_search", "$(exper.save_name)_best_eccco_delta.jls"))
+        outcomes = Serialization.deserialize(joinpath(exper.output_path, "grid_search", "$(exper.save_name).jls"))
+        outcome = best_absolute_outcome_eccco_Δ(outcomes)
     else
         # Run the experiment:
         outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
diff --git a/experiments/german_credit.jl b/experiments/german_credit.jl
index 1d714825..7c81aeb1 100644
--- a/experiments/german_credit.jl
+++ b/experiments/german_credit.jl
@@ -2,6 +2,18 @@
 dataname = "German Credit"
 counterfactual_data, test_data = train_test_split(load_german_credit(nothing); test_size=TEST_SIZE)
 
+# VAE:
+using CounterfactualExplanations.GenerativeModels: VAE, train!
+X = counterfactual_data.X
+y = counterfactual_data.output_encoder.y
+vae = VAE(size(X,1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5)
+train!(vae, X, y)
+counterfactual_data.generative_model = vae
+
+# Dimensionality reduction:
+maxout_dim = vae.params.latent_dim
+counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -17,7 +29,8 @@ params = (
     sampling_batch_size=10,
     sampling_steps=30,
     use_ensembling=true,
-    opt=Flux.Optimise.Descent(0.05)
+    opt=Flux.Optimise.Descent(0.05),
+    dim_reduction=true,
 )
 
 # Best grid search params:
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index 670ef2b1..8e5fe607 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -3,6 +3,18 @@ dataname = "GMSC"
 counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
 nobs = size(counterfactual_data.X, 2)
 
+# VAE:
+using CounterfactualExplanations.GenerativeModels: VAE, train!
+X = counterfactual_data.X
+y = counterfactual_data.output_encoder.y
+vae = VAE(size(X, 1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5)
+train!(vae, X, y)
+counterfactual_data.generative_model = vae
+
+# Dimensionality reduction:
+maxout_dim = vae.params.latent_dim
+counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -18,7 +30,8 @@ params = (
     sampling_batch_size = 10,
     sampling_steps = 30,
     use_ensembling = true,
-    opt = Flux.Optimise.Descent(0.05)
+    opt=Flux.Optimise.Descent(0.05),
+    dim_reduction=true,
 )
 
 # Best grid search params:
-- 
GitLab