From 9ee71bc1d7fb68cbcff7efb073d099e532909af0 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Fri, 1 Sep 2023 15:15:01 +0200 Subject: [PATCH] sorted out anonymous function issue --- experiments/experiment.jl | 2 +- experiments/gmsc.jl | 2 ++ experiments/mnist.jl | 1 + experiments/post_processing.jl | 24 +++++++++++++++++------- src/generator.jl | 19 +++++++------------ 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 469a123e..5faf62f9 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -21,7 +21,7 @@ Base.@kwdef struct Experiment use_ensembling::Bool = true coverage::Float64 = DEFAULT_COVERAGE generators::Union{Nothing,Dict} = nothing - n_individuals::Int = 50 + n_individuals::Int = 25 ce_measures::AbstractArray = CE_MEASURES model_measures::Dict = MODEL_MEASURES use_class_loss::Bool = false diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index 132e02f3..c1ca3c2b 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -14,4 +14,6 @@ run_experiment( use_ensembling = true, Λ=[0.1, 0.5, 0.5], opt = Flux.Optimise.Descent(0.05), + n_individuals = 10, + use_variants = false, ) \ No newline at end of file diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 557143e9..1b561b22 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -55,4 +55,5 @@ run_experiment( sampling_steps=25, use_ensembling = true, generators = generator_dict, + n_individuals = 5 ) \ No newline at end of file diff --git a/experiments/post_processing.jl b/experiments/post_processing.jl index 7e8f8e63..83672366 100644 --- a/experiments/post_processing.jl +++ b/experiments/post_processing.jl @@ -23,6 +23,7 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false) # Unpack: exp = outcome.exp n_obs, batch_size = meta_data(exp) + model_dict = outcome.model_dict params = DataFrame( Dict( @@ -30,13 +31,13 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false) :batch_size => batch_size, :dataname => exp.dataname, :sgld_batch_size => exp.sampling_batch_size, - # :epochs => exp.epochs, - # :n_hidden => n_hidden, - # :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, - # :activation => string(activation), - # :n_ens => n_ens, - # :lambda => string(α[3]), - # :jem_sampling_steps => jem.sampling_steps, + :epochs => exp.epochs, + :n_hidden => exp.n_hidden, + :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, + :activation => string(activation), + :n_ens => exp.n_ens, + :lambda => exp.string(exp.α[3]), + :jem_sampling_steps => exp.sampling_steps, ) ) @@ -55,6 +56,8 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) # Unpack: exp = outcome.exp generator_dict = outcome.generator_dict + Λ = exp.Λ + Λ_Δ = exp.Λ_Δ # Output: opt = first(values(generator_dict)).opt @@ -63,6 +66,13 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) :opt => string(typeof(opt)), :eta => opt.eta, :dataname => dataname, + :lambda_1 => string(Λ[1]), + :lambda_2 => string(Λ[2]), + :lambda_3 => string(Λ[3]), + :lambda_1_Δ => string(Λ_Δ[1]), + :lambda_2_Δ => string(Λ_Δ[2]), + :lambda_3_Δ => string(Λ_Δ[3]), + :n_individuals => exp.n_individuals, ) ) diff --git a/src/generator.jl b/src/generator.jl index b4d986c5..1939d565 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -26,19 +26,14 @@ function ECCCoGenerator(; loss_fun = nothing end - # Set size penalty - _set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp) + _energy_penalty = + use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin)) - # Energy penalty - _energy_penalty = function(ce::AbstractCounterfactualExplanation) - if use_energy_delta - return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin) - else - return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin) - end - end - - _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty] + _penalties = [ + (Objectives.distance_l1, []), + (ECCCo.set_size_penalty, (κ=κ, temp=temp)), + _energy_penalty, + ] λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ # Generator -- GitLab