From 89f13dafe7303429b75b81cc50e755a05149830d Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Fri, 22 Sep 2023 11:25:32 +0200 Subject: [PATCH] fashion mnist --- experiments/fmnist.jl | 24 ++++++++++++++++++------ experiments/mnist.jl | 8 ++++---- 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl index bb170d4d..42083e74 100644 --- a/experiments/fmnist.jl +++ b/experiments/fmnist.jl @@ -3,6 +3,8 @@ dataname = "Fashion MNIST" n_obs = 10000 counterfactual_data = load_fashion_mnist(n_obs) counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X) +# Adjust domain constraints to account for noise added during pre-processing: +counterfactual_data.domain = fill((minimum(counterfactual_data.X), maximum(counterfactual_data.X)), size(counterfactual_data.X, 1)) # VAE (trained on full dataset): using CounterfactualExplanations.Models: load_fashion_mnist_vae @@ -12,6 +14,10 @@ counterfactual_data.generative_model = vae # Test data: test_data = load_fashion_mnist_test() +# Dimensionality reduction: +maxout_dim = vae.params.latent_dim +counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); + # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE @@ -24,9 +30,12 @@ add_models = Dict( "LeNet-5" => lenet5, ) +# CE measures (add cosine distance): +ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_from_targets_ssim] + # Parameter choices: params = ( - n_individuals=N_IND_SPECIFIED ? N_IND : 50, + n_individuals=N_IND_SPECIFIED ? N_IND : 2, builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish), ð’Ÿx=Uniform(-1.0, 1.0), α=[1.0, 1.0, 1e-2], @@ -35,13 +44,16 @@ params = ( use_ensembling=true, use_variants=false, additional_models=add_models, - epochs=10, - nsamples=50, + epochs=100, + nsamples=10, nmin=1, niter_eccco=10, - Λ=[0.1, 0.25, 0.25], - Λ_Δ=[0.1, 0.1, 2.5], - opt=Flux.Optimise.Descent(0.1) + Λ=[0.01, 0.25, 0.25], + Λ_Δ=[0.01, 0.1, 0.3], + opt=Flux.Optimise.Descent(0.1), + reg_strength=0.0, + ce_measures=ce_measures, + dim_reduction=true, ) if !GRID_SEARCH diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 07c6d11e..a6e58e92 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -11,13 +11,13 @@ using CounterfactualExplanations.Models: load_mnist_vae vae = load_mnist_vae() counterfactual_data.generative_model = vae +# Test data: +test_data = load_mnist_test() + # Dimensionality reduction: maxout_dim = vae.params.latent_dim counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); -# Test data: -test_data = load_mnist_test() - # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE @@ -35,7 +35,7 @@ ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_f # Parameter choices: params = ( - n_individuals=N_IND_SPECIFIED ? N_IND : 2, + n_individuals=N_IND_SPECIFIED ? N_IND : 100, builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish), ð’Ÿx=Uniform(-1.0, 1.0), α=[1.0, 1.0, 1e-2], -- GitLab