From 89f13dafe7303429b75b81cc50e755a05149830d Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Fri, 22 Sep 2023 11:25:32 +0200
Subject: [PATCH] fashion mnist

---
 experiments/fmnist.jl | 24 ++++++++++++++++++------
 experiments/mnist.jl  |  8 ++++----
 2 files changed, 22 insertions(+), 10 deletions(-)

diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl
index bb170d4d..42083e74 100644
--- a/experiments/fmnist.jl
+++ b/experiments/fmnist.jl
@@ -3,6 +3,8 @@ dataname = "Fashion MNIST"
 n_obs = 10000
 counterfactual_data = load_fashion_mnist(n_obs)
 counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X)
+# Adjust domain constraints to account for noise added during pre-processing:
+counterfactual_data.domain = fill((minimum(counterfactual_data.X), maximum(counterfactual_data.X)), size(counterfactual_data.X, 1))
 
 # VAE (trained on full dataset):
 using CounterfactualExplanations.Models: load_fashion_mnist_vae
@@ -12,6 +14,10 @@ counterfactual_data.generative_model = vae
 # Test data:
 test_data = load_fashion_mnist_test()
 
+# Dimensionality reduction:
+maxout_dim = vae.params.latent_dim
+counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -24,9 +30,12 @@ add_models = Dict(
     "LeNet-5" => lenet5,
 )
 
+# CE measures (add cosine distance):
+ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_from_targets_ssim]
+
 # Parameter choices:
 params = (
-    n_individuals=N_IND_SPECIFIED ? N_IND : 50,
+    n_individuals=N_IND_SPECIFIED ? N_IND : 2,
     builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish),
     𝒟x=Uniform(-1.0, 1.0),
     α=[1.0, 1.0, 1e-2],
@@ -35,13 +44,16 @@ params = (
     use_ensembling=true,
     use_variants=false,
     additional_models=add_models,
-    epochs=10,
-    nsamples=50,
+    epochs=100,
+    nsamples=10,
     nmin=1,
     niter_eccco=10,
-    Λ=[0.1, 0.25, 0.25],
-    Λ_Δ=[0.1, 0.1, 2.5],
-    opt=Flux.Optimise.Descent(0.1)
+    Λ=[0.01, 0.25, 0.25],
+    Λ_Δ=[0.01, 0.1, 0.3],
+    opt=Flux.Optimise.Descent(0.1),
+    reg_strength=0.0,
+    ce_measures=ce_measures,
+    dim_reduction=true,
 )
 
 if !GRID_SEARCH
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 07c6d11e..a6e58e92 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -11,13 +11,13 @@ using CounterfactualExplanations.Models: load_mnist_vae
 vae = load_mnist_vae()
 counterfactual_data.generative_model = vae
 
+# Test data:
+test_data = load_mnist_test()
+
 # Dimensionality reduction:
 maxout_dim = vae.params.latent_dim
 counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
 
-# Test data:
-test_data = load_mnist_test()
-
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -35,7 +35,7 @@ ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_f
 
 # Parameter choices:
 params = (
-    n_individuals=N_IND_SPECIFIED ? N_IND : 2,
+    n_individuals=N_IND_SPECIFIED ? N_IND : 100,
     builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish),
     𝒟x=Uniform(-1.0, 1.0),
     α=[1.0, 1.0, 1e-2],
-- 
GitLab