From 42c125ed108282301208df7d68a448f3b8c7b665 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Fri, 8 Sep 2023 15:45:32 +0200
Subject: [PATCH] uh

---
 README.md                      |  1 +
 experiments/fmnist.jl          | 44 ++++++++++++++++++++++++++++++++++
 experiments/hpc_blue.sh        |  2 +-
 experiments/run_experiments.jl |  5 ++++
 4 files changed, 51 insertions(+), 1 deletion(-)
 create mode 100644 experiments/fmnist.jl

diff --git a/README.md b/README.md
index 4968380c..81e3af34 100644
--- a/README.md
+++ b/README.md
@@ -48,6 +48,7 @@ We use the following identifiers:
 - `moons` (*Moons* data)
 - `circles` (*Circles* data)
 - `mnist` (*MNIST* data)
+- `fmnist` (*Fashion MNIST* data)
 - `gmsc` (*GMSC* data)
 
 To run experiments for multiple datasets at once simply separate them with a comma `,`
diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl
new file mode 100644
index 00000000..41d8194d
--- /dev/null
+++ b/experiments/fmnist.jl
@@ -0,0 +1,44 @@
+# Training data:
+n_obs = 10000
+counterfactual_data = load_fashion_mnist(n_obs)
+counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X)
+
+# VAE (trained on full dataset):
+using CounterfactualExplanations.Models: load_fashion_mnist_vae
+vae = load_fashion_mnist_vae()
+counterfactual_data.generative_model = vae
+
+# Test data:
+test_data = load_fashion_mnist_test()
+
+# Additional models:
+add_models = Dict(
+    "LeNet-5" => lenet5,
+)
+
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
+# Run:
+run_experiment(
+    counterfactual_data, test_data; 
+    dataname="Fashion-MNIST",
+    builder= builder,
+    𝒟x = Uniform(-1.0, 1.0),
+    α = [1.0,1.0,1e-2],
+    sampling_batch_size = 10,
+    sampling_steps=50,
+    use_ensembling = true,
+    n_individuals = 5,
+    nsamples = 10,
+    nmin = 10,
+    use_variants = false,
+    use_class_loss = true,
+    additional_models=add_models,
+    epochs = 10,
+)
\ No newline at end of file
diff --git a/experiments/hpc_blue.sh b/experiments/hpc_blue.sh
index 25c71cd8..504a8c2e 100644
--- a/experiments/hpc_blue.sh
+++ b/experiments/hpc_blue.sh
@@ -11,4 +11,4 @@
 
 module load 2023r1 openmpi julia
 
-srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable threaded
+srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable threaded > experiments/hpc.log
diff --git a/experiments/run_experiments.jl b/experiments/run_experiments.jl
index 331106d4..82f2e454 100644
--- a/experiments/run_experiments.jl
+++ b/experiments/run_experiments.jl
@@ -38,6 +38,11 @@ if "mnist" in datanames
     include("mnist.jl")
 end
 
+if "fmnist" in datanames
+    @info "Running Fashion-MNIST experiment."
+    include("fmnist.jl")
+end
+
 if USE_MPI
     MPI.Finalize()
 end
-- 
GitLab