From 42c125ed108282301208df7d68a448f3b8c7b665 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Fri, 8 Sep 2023 15:45:32 +0200 Subject: [PATCH] uh --- README.md | 1 + experiments/fmnist.jl | 44 ++++++++++++++++++++++++++++++++++ experiments/hpc_blue.sh | 2 +- experiments/run_experiments.jl | 5 ++++ 4 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 experiments/fmnist.jl diff --git a/README.md b/README.md index 4968380c..81e3af34 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ We use the following identifiers: - `moons` (*Moons* data) - `circles` (*Circles* data) - `mnist` (*MNIST* data) +- `fmnist` (*Fashion MNIST* data) - `gmsc` (*GMSC* data) To run experiments for multiple datasets at once simply separate them with a comma `,` diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl new file mode 100644 index 00000000..41d8194d --- /dev/null +++ b/experiments/fmnist.jl @@ -0,0 +1,44 @@ +# Training data: +n_obs = 10000 +counterfactual_data = load_fashion_mnist(n_obs) +counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X) + +# VAE (trained on full dataset): +using CounterfactualExplanations.Models: load_fashion_mnist_vae +vae = load_fashion_mnist_vae() +counterfactual_data.generative_model = vae + +# Test data: +test_data = load_fashion_mnist_test() + +# Additional models: +add_models = Dict( + "LeNet-5" => lenet5, +) + +# Default builder: +n_hidden = 128 +activation = Flux.swish +builder = MLJFlux.@builder Flux.Chain( + Dense(n_in, n_hidden, activation), + Dense(n_hidden, n_out), +) + +# Run: +run_experiment( + counterfactual_data, test_data; + dataname="Fashion-MNIST", + builder= builder, + ð’Ÿx = Uniform(-1.0, 1.0), + α = [1.0,1.0,1e-2], + sampling_batch_size = 10, + sampling_steps=50, + use_ensembling = true, + n_individuals = 5, + nsamples = 10, + nmin = 10, + use_variants = false, + use_class_loss = true, + additional_models=add_models, + epochs = 10, +) \ No newline at end of file diff --git a/experiments/hpc_blue.sh b/experiments/hpc_blue.sh index 25c71cd8..504a8c2e 100644 --- a/experiments/hpc_blue.sh +++ b/experiments/hpc_blue.sh @@ -11,4 +11,4 @@ module load 2023r1 openmpi julia -srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable threaded +srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable threaded > experiments/hpc.log diff --git a/experiments/run_experiments.jl b/experiments/run_experiments.jl index 331106d4..82f2e454 100644 --- a/experiments/run_experiments.jl +++ b/experiments/run_experiments.jl @@ -38,6 +38,11 @@ if "mnist" in datanames include("mnist.jl") end +if "fmnist" in datanames + @info "Running Fashion-MNIST experiment." + include("fmnist.jl") +end + if USE_MPI MPI.Finalize() end -- GitLab