diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index 164e54c3997cdb6ad19b944b112e912c9c722bc2..40d9457cb0f72cae0f9b5d29ffdd2416d544e52f 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -1,13 +1,3 @@ -"The default benchmarking measures." -const default_measures = [ - CounterfactualExplanations.distance, - ECCCo.distance_from_energy, - ECCCo.distance_from_targets, - CounterfactualExplanations.Evaluation.validity, - CounterfactualExplanations.Evaluation.redundancy, - ECCCo.set_size_penalty -] - function default_generators( Λ::AbstractArray=[0.25, 0.75, 0.75], Λ_Δ::AbstractArray=[Λ[1], Λ[2], 4.0], @@ -52,14 +42,14 @@ end Run the benchmarking procedure. """ -function run_benchmark(; - n_individuals::Int, - dataname::String, - counterfactual_data::CounterfactualData, - model_dict::Dict, - generators::Union{Nothing, Dict}=nothing, - measures::AbstractArray=default_measures, -) +function run_benchmark(exp::Experiment, model_dict::Dict) + + n_individuals = exp.n_individuals + dataname = exp.dataname + counterfactual_data = exp.counterfactual_data + generators = exp.generators + measures = exp.ce_measures + # Benchmark generators: if isnothing(generators) generator_dict = default_generators() diff --git a/experiments/circles.jl b/experiments/circles.jl index c8f9e6fd458e367e8b0052a792ad6f7bd66d2e08..75a9b9caf9d8a7b7372d7108cf86840696d4bb33 100644 --- a/experiments/circles.jl +++ b/experiments/circles.jl @@ -1,6 +1,6 @@ n_obs = Int(1000 / (1.0 - test_size)) counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=test_size) -run_experiment( +run_experiment!( counterfactual_data, test_data; dataname="Circles", n_hidden=32, α=[1.0, 1.0, 1e-2], diff --git a/experiments/data/data.jl b/experiments/data/data.jl index 47101aabaf85c319410e32ae23bbaa959de99c95..0facb7fa3c8f41318d9c2e18bfc192f9ce52cbbc 100644 --- a/experiments/data/data.jl +++ b/experiments/data/data.jl @@ -1,15 +1,17 @@ -function prepare_data( - counterfactual_data::CounterfactualData; - ð’Ÿx=Normal(), - min_batch_size=128, - sampling_batch_size=50, -) +function _prepare_data(exp::Experiment) + + # Unpack data: + counterfactual_data = exp.counterfactual_data + min_batch_size = exp.min_batch_size + sampling_batch_size = exp.sampling_batch_size + ð’Ÿx = exp.ð’Ÿx + + # Data parameters: X, _ = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data) X = table(permutedims(X)) labels = counterfactual_data.output_encoder.labels input_dim, n_obs = size(counterfactual_data.X) output_dim = length(unique(labels)) - save_name = replace(lowercase(dataname), " " => "_") # Model parameters: batch_size = minimum([Int(round(n_obs / 10)), min_batch_size]) @@ -22,5 +24,20 @@ function prepare_data( input_size=(input_dim,), batch_size=sampling_batch_size, ) - return X, labels, n_obs, save_name, batch_size, sampler + return X, labels, n_obs, batch_size, sampler +end + +function meta_data(exp::Experiment) + _, _, n_obs, batch_size, _ = _prepare_data(exp::Experiment) + return n_obs, batch_size +end + +function prepare_data(exp::Experiment) + X, labels, _, _, sampler = _prepare_data(exp::Experiment) + return X, labels, sampler +end + +function batch_size(exp::Experiment) + _, _, _, batch_size, _ = _prepare_data(exp::Experiment) + return batch_size end \ No newline at end of file diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index 58d6b85b4d10150286ceda0f9e0c67db102923b8..f162f256aadd5a0069ccab063a70d071676dde7f 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -1,5 +1,5 @@ counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size) -run_experiment( +run_experiment!( counterfactual_data, test_data; dataname="GMSC", n_hidden=128, activation = Flux.swish, diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index 29a554b2411d72c0b45d4b6e7d5f8352dc64395c..259731b28933208ac07a01f131eaef90f48e9cec 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -3,4 +3,4 @@ counterfactual_data, test_data = train_test_split( load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0)); test_size=test_size ) -run_experiment(counterfactual_data, test_data; dataname="Linearly Separable") \ No newline at end of file +run_experiment!(counterfactual_data, test_data; dataname="Linearly Separable") \ No newline at end of file diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 39e13d406a24c711eeb057292901d562119e4bd9..88280bea047505853f9b769d1e095d12871ab0dc 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -41,7 +41,7 @@ generator_dict = Dict( ) # Run: -run_experiment( +run_experiment!( counterfactual_data, test_data; dataname="MNIST", n_hidden = 128, activation = Flux.swish, diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 7adbaa0e00a8f05bb47852800036948d19fcb2b4..b3e3286d18abdec4b5f3b6b3b4e512680002222c 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -1,3 +1,37 @@ include("additional_models.jl") include("default_models.jl") -include("train_models.jl") \ No newline at end of file +include("train_models.jl") + +function prepare_models(exp::Experiment) + + # Unpack data: + X, labels, sampler = prepare_data(exp::Experiment) + + # Setup: + if isnothing(exp.builder) + builder = default_builder() + end + if isnothing(exp.models) + @info "Using default models." + models = default_models(; + sampler=sampler, + builder=builder, + batch_size=batch_size(exp) + ) + end + + # Training: + if !pretrained + @info "Training models." + model_dict = train_models(models, X, labels; coverage=exp.coverage) + else + @info "Loading pre-trained models." + model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(exp.save_name)_models.jls")) + end + + # Save models: + @info "Saving models to $(joinpath(exp.output_path, "$(exp.save_name)_models.jls"))." + Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_models.jls"), model_dict) + + return model_dict +end \ No newline at end of file diff --git a/experiments/moons.jl b/experiments/moons.jl index 01e124b1a6f574c362acb5bc508c04725528f9af..ad492eb56b37b1ffd3f2f22c8be6a473211ef060 100644 --- a/experiments/moons.jl +++ b/experiments/moons.jl @@ -1,6 +1,6 @@ n_obs = Int(2500 / (1.0 - test_size)) counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=test_size) -run_experiment( +run_experiment!( counterfactual_data, test_data; dataname="Moons", epochs=500, n_hidden=32, diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl new file mode 100644 index 0000000000000000000000000000000000000000..36c4b60f791a686563dd54f923bb8ef63431a26b --- /dev/null +++ b/experiments/post_processing.jl @@ -0,0 +1,73 @@ +""" + meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES) + +Compute and save the model performance for the models in `outcome.model_dict`. +""" +function meta_model_performance(outcome::ExperimentOutcome; measure=MODEL_MEASURES) + + exp = outcome.exp + model_dict = outcome.model_dict + + # Model performance: + model_performance = DataFrame() + for (mod_name, model) in model_dict + # Test performance: + _perf = CounterfactualExplanations.Models.model_evaluation(model, exp.test_data, measure=collect(values(measure))) + _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) + _perf.mod_name .= mod_name + _perf.dataname .= dataname + model_performance = vcat(model_performance, _perf) + end + Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_model_performance.jls"), model_performance) + CSV.write(joinpath(exp.output_path, "$(exp.save_name)_model_performance.csv"), model_performance) + @info "Model performance:" + println(model_performance) + return model_performance +end + +""" + meta_data(exp::Experiment) + +Extract and save meta data about the experiment. +""" +function meta_data(outcome::ExperimentOutcome) + + # Data params: + _, _, n_obs, default_save_name, batch_size, sampler = prepare_data( + exp.counterfactual_data; + ð’Ÿx=exp.ð’Ÿx, + sampling_batch_size=exp.sampling_batch_size + ) + save_name = isnothing(save_name) ? default_save_name : save_name + + params = DataFrame( + Dict( + :n_obs => Int.(round(n_obs / 10) * 10), + :epochs => epochs, + :batch_size => batch_size, + :n_hidden => n_hidden, + :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, + :activation => string(activation), + :n_ens => n_ens, + :lambda => string(α[3]), + :jem_sampling_steps => jem.sampling_steps, + :sgld_batch_size => sampler.batch_size, + :dataname => dataname, + ) + ) + if !isnothing(save_path) + CSV.write(joinpath(save_path, "$(save_name)_model_params.csv"), params) + end + + # Output: + opt = first(values(generator_dict)).opt + generator_params = DataFrame( + Dict( + :opt => string(typeof(opt)), + :eta => opt.eta, + :dataname => dataname, + ) + ) + CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params) + +end \ No newline at end of file diff --git a/experiments/setup.jl b/experiments/setup.jl index e50984a50e23caf534859c064f5e2f69bd24e51c..93834ab85ae2f4f13657913d6aaa2cc1a87de9e7 100644 --- a/experiments/setup.jl +++ b/experiments/setup.jl @@ -8,187 +8,105 @@ test_size = 0.2 const DEFAULT_OUTPUT_PATH = "$(pwd())/results" const RETRAIN = "retrain" ∈ ARGS ? true : false +"Default model performance measures." +const MODEL_MEASURES = Dict( + :f1score => multiclass_f1score, + :acc => accuracy, + :precision => multiclass_precision +) + +"Default coverage rate." +const DEFAULT_COVERAGE = 0.95 + +"The default benchmarking measures." +const CE_MEASURES = [ + CounterfactualExplanations.distance, + ECCCo.distance_from_energy, + ECCCo.distance_from_targets, + CounterfactualExplanations.Evaluation.validity, + CounterfactualExplanations.Evaluation.redundancy, + ECCCo.set_size_penalty +] + # Pre-trained models: function pretrained_path() @info "Models were pre-trained on `julia-1.8.5` and may not work on other versions." return joinpath(artifact"results-paper-submission-1.8.5", "results-paper-submission-1.8.5") end -# Scripts: -include("data/data.jl") -include("models/models.jl") -include("benchmarking/benchmarking.jl") - "Sets up the experiment." Base.@kwdef struct Experiment counterfactual_data::CounterfactualData test_data::CounterfactualData dataname::String = "dataset" + save_name::String = replace(lowercase(dataname), " " => "_") output_path::String = DEFAULT_OUTPUT_PATH params_path::String = joinpath(output_path, "params") use_pretrained::Bool = true - models::Union{Nothing, Dict} = nothing - builder::Union{Nothing, MLJFlux.GenericBuilder} = nothing + models::Union{Nothing,Dict} = nothing + builder::Union{Nothing,MLJFlux.GenericBuilder} = nothing ð’Ÿx::Distribution = Normal() sampling_batch_size::Int = 50 - coverage::Float64 = 0.95 - generators::Union{Nothing, Dict} = nothing + min_batch_size::Int = 128 + coverage::Float64 = DEFAULT_COVERAGE + generators::Union{Nothing,Dict} = nothing n_individuals::Int = 50 + ce_measures::AbstractArray = CE_MEASURES +end + +"A container to hold the results of an experiment." +mutable struct ExperimentOutcome + exp::Experiment + model_dict::Union{Nothing, Dict} + generator_dict::Union{Nothing, Dict} + bmk::Union{Nothing, Benchmark} end +# Scripts: +include("data/data.jl") +include("models/models.jl") +include("benchmarking/benchmarking.jl") +include("post_processing.jl") + """ - run_experiment(exp::Experiment) + run_experiment!(exp::Experiment) Run the experiment specified by `exp`. """ -function run_experiment(exp::Experiment) +function run_experiment!(exp::Experiment) - # SETUP ---------- + # Setup @info "All results will be saved to $(exp.output_path)." isdir(exp.output_path) || mkdir(exp.output_path) @info "All parameter choices will be saved to $(exp.params_path)." isdir(exp.params_path) || mkdir(exp.params_path) - # Data - X, labels, n_obs, save_name, batch_size, sampler = prepare_data( - counterfactual_data; - ð’Ÿx=exp.ð’Ÿx, - sampling_batch_size=exp.sampling_batch_size, - ) - - # MODELS ---------- - if isnothing(builder) - builder = default_builder() - end - if isnothing(models) - @info "Using default models." - models = default_models(; - sampler=sampler, - builder=builder, - batch_size=batch_size, - ) - end - - # TRAINING ---------- - if !pretrained - @info "Training models." - model_dict = train_models(models, X, labels; coverage=coverage) - Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict) - else - @info "Loading pre-trained models." - model_dict = Serialization.deserialize(joinpath(pretrained_path(), "$(save_name)_models.jls")) - end - - params = DataFrame( - Dict( - :n_obs => Int.(round(n_obs/10)*10), - :epochs => epochs, - :batch_size => batch_size, - :n_hidden => n_hidden, - :n_layers => length(model_dict["MLP"].fitresult[1][1])-1, - :activation => string(activation), - :n_ens => n_ens, - :lambda => string(α[3]), - :jem_sampling_steps => jem.sampling_steps, - :sgld_batch_size => sampler.batch_size, - :dataname => dataname, - ) - ) - CSV.write(joinpath(params_path, "$(save_name)_model_params.csv"), params) - - measure = Dict( - :f1score => multiclass_f1score, - :acc => accuracy, - :precision => multiclass_precision - ) - model_performance = DataFrame() - for (mod_name, model) in model_dict - # Test performance: - _perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure))) - _perf = DataFrame([[p] for p in _perf], collect(keys(measure))) - _perf.mod_name .= mod_name - _perf.dataname .= dataname - model_performance = vcat(model_performance, _perf) - end - Serialization.serialize(joinpath(output_path, "$(save_name)_model_performance.jls"), model_performance) - CSV.write(joinpath(output_path, "$(save_name)_model_performance.csv"), model_performance) - @info "Model performance:" - println(model_performance) + # Models + model_dict = prepare_models(exp) + outcome = ExperimentOutcome(exp, model_dict, nothing, nothing) + meta_model_performance(outcome) - # COUNTERFACTUALS ---------- - # Benchmark generators: - bmk, generator_dict = run_benchmark(; - n_individuals=n_individuals, - dataname=dataname, - counterfactual_data=counterfactual_data, - model_dict=model_dict, - generators=generators, - measures=measures, - ) + # Benchmark + bmk, generator_dict = run_benchmark(exp, model_dict) + outcome.bmk = bmk + outcome.generator_dict = generator_dict - # Output: - opt = first(values(generator_dict)).opt - generator_params = DataFrame( - Dict( - :opt => string(typeof(opt)), - :eta => opt.eta, - :dataname => dataname, - ) - ) - CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params) - CSV.write(joinpath(output_path, "$(save_name)_benchmark.csv"), bmk()) + # Save data: + Serialization.serialize(joinpath(exp.output_path, "$(exp.save_name)_outcome.jls"), outcome) + CSV.write(joinpath(exp.output_path, "$(exp.ave_name)_benchmark.csv"), bmk()) end """ - run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) + run_experiment!(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) Overload the `run_experiment` function to allow for passing in `CounterfactualData` objects and other keyword arguments. """ -function run_experiment(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) +function run_experiment!(counterfactual_data::CounterfactualData, test_data::CounterfactualData; kwargs...) # Parameters: exp = Experiment( counterfactual_data, test_data; kwargs... ) - return run_experiment(exp) -end - -""" - meta_data(exp::Experiment) - -Extract and save meta data about the experiment. -""" -function meta_data( - exp::Experiment; - save_path::Union{String,Nothing}=nothing, - save_name::Union{String,Nothing}=nothing, -) - - # Data params: - _, _, n_obs, default_save_name, batch_size, sampler = prepare_data( - exp.counterfactual_data; - ð’Ÿx=exp.ð’Ÿx, - sampling_batch_size=exp.sampling_batch_size - ) - save_name = isnothing(save_name) ? default_save_name : save_name - - params = DataFrame( - Dict( - :n_obs => Int.(round(n_obs / 10) * 10), - :epochs => epochs, - :batch_size => batch_size, - :n_hidden => n_hidden, - :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, - :activation => string(activation), - :n_ens => n_ens, - :lambda => string(α[3]), - :jem_sampling_steps => jem.sampling_steps, - :sgld_batch_size => sampler.batch_size, - :dataname => dataname, - ) - ) - if !isnothing(save_path) - CSV.write(joinpath(save_path, "$(save_name)_model_params.csv"), params) - end + return run_experiment!(exp) end \ No newline at end of file diff --git a/replicated/linearly_separable_model_performance.csv b/replicated/linearly_separable_model_performance.csv new file mode 100644 index 0000000000000000000000000000000000000000..b553ee647749944dee7df1fafed939a9231f9659 --- /dev/null +++ b/replicated/linearly_separable_model_performance.csv @@ -0,0 +1,3 @@ +acc,precision,f1score,mod_name,dataname +0.992,0.992,0.992,MLP,Linearly Separable +0.992,0.9921259842519685,0.9919994879672299,JEM,Linearly Separable diff --git a/replicated/linearly_separable_model_performance.jls b/replicated/linearly_separable_model_performance.jls new file mode 100644 index 0000000000000000000000000000000000000000..2da7edaff6f46dceb83f9dfeb037d1a89ca2faa2 Binary files /dev/null and b/replicated/linearly_separable_model_performance.jls differ diff --git a/replicated/params/linearly_separable_generator_params.csv b/replicated/params/linearly_separable_generator_params.csv new file mode 100644 index 0000000000000000000000000000000000000000..ab1f034255091905f433476b9ba04befd95b9082 --- /dev/null +++ b/replicated/params/linearly_separable_generator_params.csv @@ -0,0 +1,2 @@ +dataname,eta,opt,λ1,λ2,λ3 +Linearly Separable,0.01,Descent,0.25,0.75,0.75 diff --git a/replicated/params/linearly_separable_model_params.csv b/replicated/params/linearly_separable_model_params.csv new file mode 100644 index 0000000000000000000000000000000000000000..0a58fca15b7af3f809adca966e053ecae2c9a335 --- /dev/null +++ b/replicated/params/linearly_separable_model_params.csv @@ -0,0 +1,2 @@ +activation,batch_size,dataname,epochs,jem_sampling_steps,lambda,n_ens,n_hidden,n_layers,n_obs,sgld_batch_size +swish,100,Linearly Separable,100,30,0.1,5,16,3,1000,50