diff --git a/README.md b/README.md index 4968380cc8928148e7ad9e27d69e9f48ff4de8f6..81e3af3426ee82b513a951032eabde92a0eee7e6 100644 --- a/README.md +++ b/README.md @@ -48,6 +48,7 @@ We use the following identifiers: - `moons` (*Moons* data) - `circles` (*Circles* data) - `mnist` (*MNIST* data) +- `fmnist` (*Fashion MNIST* data) - `gmsc` (*GMSC* data) To run experiments for multiple datasets at once simply separate them with a comma `,` diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl new file mode 100644 index 0000000000000000000000000000000000000000..41d8194dfb1e63dc99f2c4bc4cacdc862d597c2f --- /dev/null +++ b/experiments/fmnist.jl @@ -0,0 +1,44 @@ +# Training data: +n_obs = 10000 +counterfactual_data = load_fashion_mnist(n_obs) +counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X) + +# VAE (trained on full dataset): +using CounterfactualExplanations.Models: load_fashion_mnist_vae +vae = load_fashion_mnist_vae() +counterfactual_data.generative_model = vae + +# Test data: +test_data = load_fashion_mnist_test() + +# Additional models: +add_models = Dict( + "LeNet-5" => lenet5, +) + +# Default builder: +n_hidden = 128 +activation = Flux.swish +builder = MLJFlux.@builder Flux.Chain( + Dense(n_in, n_hidden, activation), + Dense(n_hidden, n_out), +) + +# Run: +run_experiment( + counterfactual_data, test_data; + dataname="Fashion-MNIST", + builder= builder, + ð’Ÿx = Uniform(-1.0, 1.0), + α = [1.0,1.0,1e-2], + sampling_batch_size = 10, + sampling_steps=50, + use_ensembling = true, + n_individuals = 5, + nsamples = 10, + nmin = 10, + use_variants = false, + use_class_loss = true, + additional_models=add_models, + epochs = 10, +) \ No newline at end of file diff --git a/experiments/hpc_blue.sh b/experiments/hpc_blue.sh index 25c71cd8929f46db060d826ae8a77aaf4f528abe..504a8c2e989dd0a38cbcb31234d400c8cbc73155 100644 --- a/experiments/hpc_blue.sh +++ b/experiments/hpc_blue.sh @@ -11,4 +11,4 @@ module load 2023r1 openmpi julia -srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable threaded +srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable threaded > experiments/hpc.log diff --git a/experiments/run_experiments.jl b/experiments/run_experiments.jl index 331106d4bb768af69b128b25afb5a766ee9322ae..82f2e4543984d284a04df46d634292cdc8358ac1 100644 --- a/experiments/run_experiments.jl +++ b/experiments/run_experiments.jl @@ -38,6 +38,11 @@ if "mnist" in datanames include("mnist.jl") end +if "fmnist" in datanames + @info "Running Fashion-MNIST experiment." + include("fmnist.jl") +end + if USE_MPI MPI.Finalize() end