Patch summary (free text ahead of the first "diff --git" header; patch tools skip it).

Data scripts california_housing.jl, german_credit.jl and gmsc.jl each gain the same setup:
  1. Build a VAE over size(X, 1) inputs (nll=Flux.Losses.mse, epochs=100, λ=0.01,
     latent_dim=5), call train!(vae, X, y) and store it in
     counterfactual_data.generative_model.
  2. Fit a MultivariateStats.PCA on counterfactual_data.X with maxoutdim set to
     vae.params.latent_dim and store it in counterfactual_data.dt.
  3. Add dim_reduction=true to the params tuple (the preceding opt entry gets a comma).

Script experiment.jl, FROM_GRID_SEARCH branch of run_experiment: deserialize
"$(exper.save_name).jls" from the grid_search directory and hand the result to
best_absolute_outcome_eccco_Δ, instead of deserializing
"$(exper.save_name)_best_eccco_delta.jls" directly.

NOTE(review): line breaks were restored from a newline-collapsed copy of this patch.
Context-line indentation is assumed to be 4 spaces per level; confirm with
git apply --check before applying.

diff --git a/experiments/california_housing.jl b/experiments/california_housing.jl
index fa965f21c0621e510f4c082ebd149f756c36552c..fe6de6e60e28255ab0dea548e2e80759c4949b0e 100644
--- a/experiments/california_housing.jl
+++ b/experiments/california_housing.jl
@@ -2,6 +2,18 @@
 dataname = "California Housing"
 counterfactual_data, test_data = train_test_split(load_california_housing(nothing); test_size=TEST_SIZE)
 
+# VAE:
+using CounterfactualExplanations.GenerativeModels: VAE, train!
+X = counterfactual_data.X
+y = counterfactual_data.output_encoder.y
+vae = VAE(size(X, 1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5)
+train!(vae, X, y)
+counterfactual_data.generative_model = vae
+
+# Dimensionality reduction:
+maxout_dim = vae.params.latent_dim
+counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -17,7 +29,8 @@ params = (
     sampling_batch_size=10,
     sampling_steps=30,
     use_ensembling=true,
-    opt=Flux.Optimise.Descent(0.05)
+    opt=Flux.Optimise.Descent(0.05),
+    dim_reduction=true,
 )
 
 # Best grid search params:
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 4f84ce3a356109ba8295fef546dd2ff56eec0c33..c7f00302f43e1deb2f3873c106b54c803d405046 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -94,7 +94,8 @@ function run_experiment(exper::Experiment; save_output::Bool=true, only_models::
 
     if FROM_GRID_SEARCH
         # Just load the best model from the grid search:
-        outcome = Serialization.deserialize(joinpath(exper.output_path, "grid_search", "$(exper.save_name)_best_eccco_delta.jls"))
+        outcomes = Serialization.deserialize(joinpath(exper.output_path, "grid_search", "$(exper.save_name).jls"))
+        outcome = best_absolute_outcome_eccco_Δ(outcomes)
     else
         # Run the experiment:
         outcome = ExperimentOutcome(exper, nothing, nothing, nothing)
diff --git a/experiments/german_credit.jl b/experiments/german_credit.jl
index 1d714825c6255e46b35085506a5acb6c23b9c513..7c81aeb1e7a60e780f49a39b7521115c257214aa 100644
--- a/experiments/german_credit.jl
+++ b/experiments/german_credit.jl
@@ -2,6 +2,18 @@
 dataname = "German Credit"
 counterfactual_data, test_data = train_test_split(load_german_credit(nothing); test_size=TEST_SIZE)
 
+# VAE:
+using CounterfactualExplanations.GenerativeModels: VAE, train!
+X = counterfactual_data.X
+y = counterfactual_data.output_encoder.y
+vae = VAE(size(X,1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5)
+train!(vae, X, y)
+counterfactual_data.generative_model = vae
+
+# Dimensionality reduction:
+maxout_dim = vae.params.latent_dim
+counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -17,7 +29,8 @@ params = (
     sampling_batch_size=10,
     sampling_steps=30,
     use_ensembling=true,
-    opt=Flux.Optimise.Descent(0.05)
+    opt=Flux.Optimise.Descent(0.05),
+    dim_reduction=true,
 )
 
 # Best grid search params:
diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index 670ef2b16672ba9af669e1d0745d666cc029da94..8e5fe607ef181b1e17c6701aecd5c26ab7a73ab8 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -3,6 +3,18 @@ dataname = "GMSC"
 counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
 nobs = size(counterfactual_data.X, 2)
 
+# VAE:
+using CounterfactualExplanations.GenerativeModels: VAE, train!
+X = counterfactual_data.X
+y = counterfactual_data.output_encoder.y
+vae = VAE(size(X, 1); nll=Flux.Losses.mse, epochs=100, λ=0.01, latent_dim=5)
+train!(vae, X, y)
+counterfactual_data.generative_model = vae
+
+# Dimensionality reduction:
+maxout_dim = vae.params.latent_dim
+counterfactual_data.dt = MultivariateStats.fit(MultivariateStats.PCA, counterfactual_data.X; maxoutdim=maxout_dim);
+
 # Model tuning:
 model_tuning_params = DEFAULT_MODEL_TUNING_LARGE
 
@@ -18,7 +30,8 @@ params = (
     sampling_batch_size = 10,
     sampling_steps = 30,
     use_ensembling = true,
-    opt = Flux.Optimise.Descent(0.05)
+    opt=Flux.Optimise.Descent(0.05),
+    dim_reduction=true,
 )
 
 # Best grid search params: