diff --git a/experiments/fmnist.jl b/experiments/fmnist.jl index bb170d4d0a4df86b69d045aab44f86e6ee358cff..42083e7404988c4f5afaadc490439ad2e3b77e2b 100644 --- a/experiments/fmnist.jl +++ b/experiments/fmnist.jl @@ -3,6 +3,8 @@ dataname = "Fashion MNIST" n_obs = 10000 counterfactual_data = load_fashion_mnist(n_obs) counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X) +# Adjust domain constraints to account for noise added during pre-processing: +counterfactual_data.domain = fill((minimum(counterfactual_data.X), maximum(counterfactual_data.X)), size(counterfactual_data.X, 1)) # VAE (trained on full dataset): using CounterfactualExplanations.Models: load_fashion_mnist_vae @@ -12,6 +14,10 @@ counterfactual_data.generative_model = vae # Test data: test_data = load_fashion_mnist_test() +# Dimensionality reduction: +maxout_dim = vae.params.latent_dim +counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); + # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE @@ -24,9 +30,12 @@ add_models = Dict( "LeNet-5" => lenet5, ) +# CE measures (add cosine distance): +ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_from_targets_ssim] + # Parameter choices: params = ( - n_individuals=N_IND_SPECIFIED ? N_IND : 50, + n_individuals=N_IND_SPECIFIED ? N_IND : 2, builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish), ð’Ÿx=Uniform(-1.0, 1.0), α=[1.0, 1.0, 1e-2], @@ -35,13 +44,16 @@ params = ( use_ensembling=true, use_variants=false, additional_models=add_models, - epochs=10, - nsamples=50, + epochs=100, + nsamples=10, nmin=1, niter_eccco=10, - Λ=[0.1, 0.25, 0.25], - Λ_Δ=[0.1, 0.1, 2.5], - opt=Flux.Optimise.Descent(0.1) + Λ=[0.01, 0.25, 0.25], + Λ_Δ=[0.01, 0.1, 0.3], + opt=Flux.Optimise.Descent(0.1), + reg_strength=0.0, + ce_measures=ce_measures, + dim_reduction=true, ) if !GRID_SEARCH diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 07c6d11e4775381cf128b727138511c092220e2d..a6e58e92d0298d3d9d9c4b5088f578a47928863c 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -11,13 +11,13 @@ using CounterfactualExplanations.Models: load_mnist_vae vae = load_mnist_vae() counterfactual_data.generative_model = vae +# Test data: +test_data = load_mnist_test() + # Dimensionality reduction: maxout_dim = vae.params.latent_dim counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim); -# Test data: -test_data = load_mnist_test() - # Model tuning: model_tuning_params = DEFAULT_MODEL_TUNING_LARGE @@ -35,7 +35,7 @@ ce_measures = [CE_MEASURES..., ECCCo.distance_from_energy_ssim, ECCCo.distance_f # Parameter choices: params = ( - n_individuals=N_IND_SPECIFIED ? N_IND : 2, + n_individuals=N_IND_SPECIFIED ? N_IND : 100, builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish), ð’Ÿx=Uniform(-1.0, 1.0), α=[1.0, 1.0, 1e-2],