# ECCCo experiment script: Fashion MNIST.
#
# Loads the training and test data plus a pre-trained VAE, builds the
# generator tuning grid and the per-run parameter tuple, and then calls either
# `grid_search` or `run_experiment` depending on the global `GRID_SEARCH` flag.
#
# NOTE(review): several names used here are not defined in this file
# (e.g. `load_fashion_mnist`, `DEFAULT_GENERATOR_TUNING`, `lenet5`,
# `N_IND_SPECIFIED`, `N_IND`, `GRID_SEARCH`, `run_experiment`, `grid_search`).
# Presumably the environment that includes this script provides them — confirm.
# Global names assigned below are left unchanged because other files may read them.

# ---- Training data ------------------------------------------------------------
dataname = "Fashion MNIST"
n_obs = 10000
counterfactual_data = load_fashion_mnist(n_obs)
# Broadcast ECCCo's pre-processing over every element of the feature array.
counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X)

# ---- VAE (trained on the full dataset) -----------------------------------------
using CounterfactualExplanations.Models: load_fashion_mnist_vae
vae = load_fashion_mnist_vae()
counterfactual_data.generative_model = vae

# ---- Test data ------------------------------------------------------------------
test_data = load_fashion_mnist_test()

# ---- Model tuning ---------------------------------------------------------------
# NOTE(review): nothing in this file reads `model_tuning_params` (it is not
# passed to `run_experiment` below). Presumably it is read as a global
# elsewhere — confirm that, or pass it explicitly if that was the intent.
model_tuning_params = DEFAULT_MODEL_TUNING_LARGE

# ---- Generator tuning grid ------------------------------------------------------
# Copy the default grid, drop the first entry of its `Λ` list and append
# `[0.1, 0.1, 3.0]` as a new final entry. All other fields carry over unchanged.
tuning_params = let base = DEFAULT_GENERATOR_TUNING
    Λ_candidates = [base.Λ[2:end]..., [0.1, 0.1, 3.0]]
    (; base..., Λ=Λ_candidates)
end

# ---- Additional models ----------------------------------------------------------
add_models = Dict("LeNet-5" => lenet5)

# ---- Parameter choices ----------------------------------------------------------
params = let
    # Use the globally requested number of individuals when one was given;
    # otherwise fall back to 50.
    n_ind = if N_IND_SPECIFIED
        N_IND
    else
        50
    end
    (
        n_individuals=n_ind,
        builder=default_builder(n_hidden=128, n_layers=1, activation=Flux.swish),
        𝒟x=Uniform(-1.0, 1.0),
        α=[1.0, 1.0, 1e-2],
        sampling_batch_size=10,
        sampling_steps=25,
        use_ensembling=true,
        use_variants=false,
        additional_models=add_models,
        epochs=10,
        nsamples=50,
        nmin=1,
        niter_eccco=10,
        Λ=[0.1, 0.25, 0.25],
        Λ_Δ=[0.1, 0.1, 2.5],
    )
end

# ---- Dispatch -------------------------------------------------------------------
if GRID_SEARCH
    # Grid search over `tuning_params` with a fixed 5 individuals. The `params`
    # tuple above is not used on this path.
    grid_search(
        counterfactual_data, test_data;
        dataname=dataname,
        tuning_params=tuning_params,
        n_individuals=5
    )
else
    # Main experiment: splat every parameter choice in as a keyword argument.
    run_experiment(
        counterfactual_data, test_data;
        dataname=dataname,
        params...
    )
end