Skip to content
Snippets Groups Projects
Commit 9ee71bc1 authored by Pat Alt's avatar Pat Alt
Browse files

sorted out anonymous function issue

parent 9221e51f
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
...@@ -21,7 +21,7 @@ Base.@kwdef struct Experiment ...@@ -21,7 +21,7 @@ Base.@kwdef struct Experiment
use_ensembling::Bool = true use_ensembling::Bool = true
coverage::Float64 = DEFAULT_COVERAGE coverage::Float64 = DEFAULT_COVERAGE
generators::Union{Nothing,Dict} = nothing generators::Union{Nothing,Dict} = nothing
n_individuals::Int = 50 n_individuals::Int = 25
ce_measures::AbstractArray = CE_MEASURES ce_measures::AbstractArray = CE_MEASURES
model_measures::Dict = MODEL_MEASURES model_measures::Dict = MODEL_MEASURES
use_class_loss::Bool = false use_class_loss::Bool = false
......
...@@ -14,4 +14,6 @@ run_experiment( ...@@ -14,4 +14,6 @@ run_experiment(
use_ensembling = true, use_ensembling = true,
Λ=[0.1, 0.5, 0.5], Λ=[0.1, 0.5, 0.5],
opt = Flux.Optimise.Descent(0.05), opt = Flux.Optimise.Descent(0.05),
n_individuals = 10,
use_variants = false,
) )
\ No newline at end of file
...@@ -55,4 +55,5 @@ run_experiment( ...@@ -55,4 +55,5 @@ run_experiment(
sampling_steps=25, sampling_steps=25,
use_ensembling = true, use_ensembling = true,
generators = generator_dict, generators = generator_dict,
n_individuals = 5
) )
\ No newline at end of file
...@@ -23,6 +23,7 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false) ...@@ -23,6 +23,7 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false)
# Unpack: # Unpack:
exp = outcome.exp exp = outcome.exp
n_obs, batch_size = meta_data(exp) n_obs, batch_size = meta_data(exp)
model_dict = outcome.model_dict
params = DataFrame( params = DataFrame(
Dict( Dict(
...@@ -30,13 +31,13 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false) ...@@ -30,13 +31,13 @@ function meta_model(outcome::ExperimentOutcome; save_output::Bool=false)
:batch_size => batch_size, :batch_size => batch_size,
:dataname => exp.dataname, :dataname => exp.dataname,
:sgld_batch_size => exp.sampling_batch_size, :sgld_batch_size => exp.sampling_batch_size,
# :epochs => exp.epochs, :epochs => exp.epochs,
# :n_hidden => n_hidden, :n_hidden => exp.n_hidden,
# :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1, :n_layers => length(model_dict["MLP"].fitresult[1][1]) - 1,
# :activation => string(activation), :activation => string(activation),
# :n_ens => n_ens, :n_ens => exp.n_ens,
# :lambda => string(α[3]), :lambda => exp.string(exp.α[3]),
# :jem_sampling_steps => jem.sampling_steps, :jem_sampling_steps => exp.sampling_steps,
) )
) )
...@@ -55,6 +56,8 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) ...@@ -55,6 +56,8 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false)
# Unpack: # Unpack:
exp = outcome.exp exp = outcome.exp
generator_dict = outcome.generator_dict generator_dict = outcome.generator_dict
Λ = exp.Λ
Λ_Δ = exp.Λ_Δ
# Output: # Output:
opt = first(values(generator_dict)).opt opt = first(values(generator_dict)).opt
...@@ -63,6 +66,13 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false) ...@@ -63,6 +66,13 @@ function meta_generators(outcome::ExperimentOutcome; save_output::Bool=false)
:opt => string(typeof(opt)), :opt => string(typeof(opt)),
:eta => opt.eta, :eta => opt.eta,
:dataname => dataname, :dataname => dataname,
:lambda_1 => string(Λ[1]),
:lambda_2 => string(Λ[2]),
:lambda_3 => string(Λ[3]),
:lambda_1_Δ => string(Λ_Δ[1]),
:lambda_2_Δ => string(Λ_Δ[2]),
:lambda_3_Δ => string(Λ_Δ[3]),
:n_individuals => exp.n_individuals,
) )
) )
......
...@@ -26,19 +26,14 @@ function ECCCoGenerator(; ...@@ -26,19 +26,14 @@ function ECCCoGenerator(;
loss_fun = nothing loss_fun = nothing
end end
# Set size penalty _energy_penalty =
_set_size_penalty = (ce::AbstractCounterfactualExplanation) -> ECCCo.set_size_penalty(ce; κ=κ, temp=temp) use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin))
# Energy penalty _penalties = [
_energy_penalty = function(ce::AbstractCounterfactualExplanation) (Objectives.distance_l1, []),
if use_energy_delta (ECCCo.set_size_penalty, (κ=κ, temp=temp)),
return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin) _energy_penalty,
else ]
return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin)
end
end
_penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ λ = λ isa AbstractFloat ? [0.0, λ, λ] : λ
# Generator # Generator
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment