# mnist.jl
# Training data: obtained from `load_mnist(n_obs)`, then `ECCCo.pre_process` is broadcast
# over every element of `X`.
dataname = "MNIST"
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = ECCCo.pre_process.(counterfactual_data.X)

# Adjust domain constraints to account for noise added during pre-processing:
# the global (min, max) of `X` is repeated once per row of `X`.
counterfactual_data.domain =
    fill(extrema(counterfactual_data.X), size(counterfactual_data.X, 1))

# VAE (trained on full dataset), stored on the data object as its generative model:
using CounterfactualExplanations.Models: load_mnist_vae
vae = load_mnist_vae()
counterfactual_data.generative_model = vae

# Test data:
test_data = load_mnist_test()

# Dimensionality reduction: fit a PCA on `counterfactual_data.X`, with `maxoutdim` set to
# the VAE's `latent_dim`, and keep the fitted transform in `counterfactual_data.dt`.
maxout_dim = vae.params.latent_dim
counterfactual_data.dt = MultivariateStats.fit(
    MultivariateStats.PCA, counterfactual_data.X; maxoutdim = maxout_dim
);

# Model tuning:
model_tuning_params = DEFAULT_MODEL_TUNING_LARGE

# Tuning parameters: copy the generator defaults, overriding only `Λ` — its first entry
# is dropped and `[0.01, 0.1, 3.0]` is appended as one additional entry.
tuning_params = let defaults = DEFAULT_GENERATOR_TUNING
    (; defaults..., Λ = [defaults.Λ[2:end]..., [0.01, 0.1, 3.0]])
end

# Additional models:
add_models = Dict("LeNet-5" => lenet5)

# CE measures: `CE_MEASURES` plus the two `*_ssim` distance measures from ECCCo.
ce_measures = [
    CE_MEASURES...,
    ECCCo.distance_from_energy_ssim,
    ECCCo.distance_from_targets_ssim,
]

# Parameter choices, collected in a NamedTuple that is later splatted as keyword
# arguments into `grid_search` / `run_experiment`.
params = (;
    # Use `N_IND` when `N_IND_SPECIFIED` is true, otherwise fall back to 100.
    n_individuals = N_IND_SPECIFIED ? N_IND : 100,
    builder = default_builder(; n_hidden = 128, n_layers = 1, activation = Flux.swish),
    𝒟x = Uniform(-1.0, 1.0),
    α = [1.0, 1.0, 1e-2],
    sampling_batch_size = 10,
    sampling_steps = 25,
    use_ensembling = true,
    use_variants = false,
    additional_models = add_models,
    epochs = 100,
    nsamples = 10,
    nmin = 1,
    niter_eccco = 10,
    Λ = [0.01, 0.25, 0.25],
    Λ_Δ = [0.01, 0.1, 0.3],
    opt = Flux.Optimise.Descent(0.1),
    reg_strength = 0.0,
    ce_measures,
    dim_reduction = true,
)

# Best grid search params: `append_best_params` returns the parameter set used below.
params = append_best_params(params, dataname)

# Select what to run based on the `GRID_SEARCH` / `FROM_GRID_SEARCH` flags.
if GRID_SEARCH
    grid_search(counterfactual_data, test_data; dataname, tuning_params, params...)
elseif FROM_GRID_SEARCH
    # File name: lower-cased `dataname` with spaces replaced by underscores, plus ".jls".
    outcomes_file_path = joinpath(
        DEFAULT_OUTPUT_PATH,
        "grid_search",
        string(replace(lowercase(dataname), " " => "_"), ".jls"),
    )
    save_best(outcomes_file_path)
    bmk2csv(dataname)
else
    run_experiment(counterfactual_data, test_data; dataname, model_tuning_params, params...)
end