Commit 9ee71bc1 authored by Pat Alt

sorted out anonymous function issue

parent 9221e51f
1 merge request: !7669 initial run including fmnist lenet and new method
@@ -21,7 +21,7 @@ Base.@kwdef struct Experiment
     use_ensembling::Bool = true
     coverage::Float64 = DEFAULT_COVERAGE
     generators::Union{Nothing,Dict} = nothing
-    n_individuals::Int = 50
+    n_individuals::Int = 25
     ce_measures::AbstractArray = CE_MEASURES
     model_measures::Dict = MODEL_MEASURES
     use_class_loss::Bool = false
@@ -14,4 +14,6 @@ run_experiment(
     use_ensembling = true,
     Λ=[0.1, 0.5, 0.5],
     opt = Flux.Optimise.Descent(0.05),
+    n_individuals = 10,
+    use_variants = false,
 )
\ No newline at end of file
@@ -55,4 +55,5 @@ run_experiment(
     sampling_steps=25,
     use_ensembling = true,
     generators = generator_dict,
+    n_individuals = 5
 )
\ No newline at end of file
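For context on the two run-script hunks above: `n_individuals` is a keyword field of the `Experiment` struct (first hunk), so individual run scripts can override the new default of 25. A minimal sketch of that pattern, using a stripped-down stand-in for the real struct (only the fields shown in the diff; `DEFAULT_COVERAGE` is given a placeholder value here, not the package's actual constant):

```julia
# Stripped-down stand-in for the Experiment struct above; the placeholder
# value for DEFAULT_COVERAGE is illustrative only.
const DEFAULT_COVERAGE = 0.95

Base.@kwdef struct ExperimentSketch
    use_ensembling::Bool = true
    coverage::Float64 = DEFAULT_COVERAGE
    n_individuals::Int = 25          # new default introduced in this commit
end

# Run scripts override the default per experiment, as the run_experiment(...)
# calls above do with `n_individuals = 10` and `n_individuals = 5`:
exp_default = ExperimentSketch()
exp_small = ExperimentSketch(n_individuals = 5)
@assert exp_default.n_individuals == 25
@assert exp_small.n_individuals == 5
```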
@@ -23,6 +23,7 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false)
     # Unpack:
     exp = outcome.exp
     n_obs, batch_size = meta_data(exp)
+    model_dict = outcome.model_dict
     params = DataFrame(
         Dict(
@@ -30,13 +31,13 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false)
             :batch_size => batch_size,
             :dataname => exp.dataname,
             :sgld_batch_size => exp.sampling_batch_size,
-            # :epochs => exp.epochs,
-            # :n_hidden => n_hidden,
-            # :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
-            # :activation => string(activation),
-            # :n_ens => n_ens,
-            # :lambda => string(α[3]),
-            # :jem_sampling_steps => jem.sampling_steps,
+            :epochs => exp.epochs,
+            :n_hidden => exp.n_hidden,
+            :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
+            :activation => string(activation),
+            :n_ens => exp.n_ens,
+            :lambda => string(exp.α[3]),
+            :jem_sampling_steps => exp.sampling_steps,
         )
     )
@@ -55,6 +56,8 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false)
     # Unpack:
     exp = outcome.exp
     generator_dict = outcome.generator_dict
+    Λ = exp.Λ
+    Λ_Δ = exp.Λ_Δ
     # Output:
     opt = first(values(generator_dict)).opt
@@ -63,6 +66,13 @@
             :opt => string(typeof(opt)),
             :eta => opt.eta,
             :dataname => dataname,
+            :lambda_1 => string(Λ[1]),
+            :lambda_2 => string(Λ[2]),
+            :lambda_3 => string(Λ[3]),
+            :lambda_1_Δ => string(Λ_Δ[1]),
+            :lambda_2_Δ => string(Λ_Δ[2]),
+            :lambda_3_Δ => string(Λ_Δ[3]),
+            :n_individuals => exp.n_individuals,
         )
     )
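The two hunks above add experiment and generator hyperparameters to the one-row metadata DataFrames. A hedged sketch of the resulting pattern, with a NamedTuple standing in for the unpacked `exp` (all field values below are illustrative, not the experiment's actual settings):

```julia
using DataFrames

# NamedTuple standing in for the relevant fields of `exp`; values are
# illustrative only.
exp = (
    epochs = 10,
    n_hidden = 32,
    n_ens = 5,
    α = [1.0, 1.0, 0.1],
    Λ = [0.25, 0.75, 0.75],
    Λ_Δ = [0.25, 0.75, 0.75],
    n_individuals = 25,
)

# Model-level metadata (one row), mirroring meta_model above:
params = DataFrame(
    epochs = exp.epochs,
    n_hidden = exp.n_hidden,
    n_ens = exp.n_ens,
    lambda = string(exp.α[3]),
)

# Generator-level metadata, mirroring meta_generators above:
Λ, Λ_Δ = exp.Λ, exp.Λ_Δ
generator_params = DataFrame(
    lambda_1 = string(Λ[1]),
    lambda_2 = string(Λ[2]),
    lambda_3 = string(Λ[3]),
    lambda_1_Δ = string(Λ_Δ[1]),
    lambda_2_Δ = string(Λ_Δ[2]),
    lambda_3_Δ = string(Λ_Δ[3]),
    n_individuals = exp.n_individuals,
)
```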
@@ -26,19 +26,14 @@ function ECCCoGenerator(;
         loss_fun = nothing
     end
-    # Set size penalty
-    _set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp)
+    _energy_penalty =
+        use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin))
-    # Energy penalty
-    _energy_penalty = function(ce::AbstractCounterfactualExplanation)
-        if use_energy_delta
-            return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin)
-        else
-            return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin)
-        end
-    end
-    _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
+    _penalties = [
+        (Objectives.distance_l1, []),
+        (ECCCo.set_size_penalty, (κ=κ, temp=temp)),
+        _energy_penalty,
+    ]
     λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ
     # Generator
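The last hunk appears to be the "anonymous function issue" from the commit message: penalty closures that captured κ, temp, nsamples and nmin are replaced by (function, keyword-arguments) tuples. A self-contained sketch of that calling convention, with toy penalty functions standing in for `Objectives.distance_l1`, `ECCCo.set_size_penalty`, `ECCCo.energy_delta` and `ECCCo.distance_from_energy` (how the real generator consumes the tuples is not shown in the diff, so the final loop is an assumption):

```julia
# Toy penalties standing in for the ECCCo/Objectives ones; only the
# (function, keyword-arguments) calling convention matters here.
distance_l1(x; kwargs...) = sum(abs, x)
set_size_penalty(x; κ = 1.0, temp = 0.1) = κ * temp * length(x)
energy_delta(x; n = 10, nmin = 5) = sum(x) / max(n, nmin)
distance_from_energy(x; n = 10, nmin = 5) = sum(abs2, x) / max(n, nmin)

use_energy_delta = true
nsamples, nmin, κ, temp = 10, 5, 1.0, 0.05

# Instead of anonymous functions that close over κ, temp, n and nmin, each
# penalty is stored together with its keyword arguments:
_energy_penalty =
    use_energy_delta ? (energy_delta, (n = nsamples, nmin = nmin)) :
    (distance_from_energy, (n = nsamples, nmin = nmin))
_penalties = [
    (distance_l1, (;)),                        # no keyword arguments (the diff uses `[]` here)
    (set_size_penalty, (κ = κ, temp = temp)),
    _energy_penalty,
]

# A consumer can then evaluate every penalty without any captured state:
x = randn(3)
total = sum(fun(x; kw...) for (fun, kw) in _penalties)
```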