diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 469a123ee9f69e20f403476d629f1e1647bb8df2..5faf62f9ed3449a9682fca5db36264d39d6bdf19 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -21,7 +21,7 @@ Base.@kwdef struct Experiment use_ensembling::Bool = true coverage::Float64 = DEFAULT_COVERAGE generators::Union{Nothing,Dict} = nothing - n_individuals::Int = 50 + n_individuals::Int = 25 ce_measures::AbstractArray = CE_MEASURES model_measures::Dict = MODEL_MEASURES use_class_loss::Bool = false diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index 132e02f34035df6840f821315485ce9af2641634..c1ca3c2b07dea6626068f4e3ffd7e960ea6e7d79 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -14,4 +14,6 @@ run_experiment( use_ensembling = true, Λ=[0.1, 0.5, 0.5], opt = Flux.Optimise.Descent(0.05), + n_individuals = 10, + use_variants = false, ) \ No newline at end of file diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 557143e98e5f707a8d495bb6cffdeef9593a19ee..1b561b2286d9fdf0367725699a30824eda695ab9 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -55,4 +55,5 @@ run_experiment( sampling_steps=25, use_ensembling = true, generators = generator_dict, + n_individuals = 5 ) \ No newline at end of file diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl index 7e8f8e63c8b326fa8433511e6ae8b5a8d6b97e80..836723661b41c01efcc3731e0e1ad7f97069983d 100644 --- a/experiments/post_processing.jl +++ b/experiments/post_processing.jl @@ -23,6 +23,7 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false) # Unpack: exp = outcome.exp n_obs, batch_size = meta_data(exp) + model_dict = outcome.model_dict params = DataFrame( Dict( @@ -30,13 +31,13 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false) :batch_size => batch_size, :dataname => exp.dataname, :sgld_batch_size => exp.sampling_batch_size, - # :epochs => exp.epochs, - # :n_hidden => n_hidden, - # :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, - # :activation => string(activation), - # :n_ens => n_ens, - # :lambda => string(α[3]), - # :jem_sampling_steps => jem.sampling_steps, + :epochs => exp.epochs, + :n_hidden => exp.n_hidden, + :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, + :activation => string(activation), + :n_ens => exp.n_ens, + :lambda => exp.string(exp.α[3]), + :jem_sampling_steps => exp.sampling_steps, ) ) @@ -55,6 +56,8 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) # Unpack: exp = outcome.exp generator_dict = outcome.generator_dict + Λ = exp.Λ + Λ_Δ = exp.Λ_Δ # Output: opt = first(values(generator_dict)).opt @@ -63,6 +66,13 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) :opt => string(typeof(opt)), :eta => opt.eta, :dataname => dataname, + :lambda_1 => string(Λ[1]), + :lambda_2 => string(Λ[2]), + :lambda_3 => string(Λ[3]), + :lambda_1_Δ => string(Λ_Δ[1]), + :lambda_2_Δ => string(Λ_Δ[2]), + :lambda_3_Δ => string(Λ_Δ[3]), + :n_individuals => exp.n_individuals, ) ) diff --git a/src/generator.jl b/src/generator.jl index b4d986c56b371649a9c7799950bdbbc7887d646e..1939d5655f894776d85ae53807adbdcca8f0b539 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -26,19 +26,14 @@ function ECCCoGenerator(; loss_fun = nothing end - # Set size penalty - _set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp) + _energy_penalty = + use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin)) - # Energy penalty - _energy_penalty = function(ce::AbstractCounterfactualExplanation) - if use_energy_delta - return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin) - else - return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin) - end - end - - _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty] + _penalties = [ + (Objectives.distance_l1, []), + (ECCCo.set_size_penalty, (κ=κ, temp=temp)), + _energy_penalty, + ] λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ # Generator