diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index f4176e586e35fc4ce4314f9ec083c28eccb0c852..4a27f5e2fcf0d141295762e8c02b89a8aecc366c 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -1,4 +1,4 @@ -function default_generators( +function default_generators(; Λ::AbstractArray=[0.25, 0.75, 0.75], Λ_Δ::AbstractArray=Λ, use_variants::Bool=true, @@ -49,10 +49,17 @@ function run_benchmark(exp::Experiment, model_dict::Dict) counterfactual_data = exp.counterfactual_data generator_dict = exp.generators measures = exp.ce_measures + parallelizer = exp.parallelizer # Benchmark generators: if isnothing(generator_dict) - generator_dict = default_generators() + generator_dict = default_generators(; + Λ=exp.Λ, + Λ_Δ=exp.Λ_Δ, + use_variants=exp.use_variants, + use_class_loss=exp.use_class_loss, + opt=exp.opt + ) end # Run benchmark: @@ -72,7 +79,8 @@ function run_benchmark(exp::Experiment, model_dict::Dict) n_individuals=n_individuals, target=target, factual=factual, initialization=:identity, - converge_when=:generator_conditions + converge_when=:generator_conditions, + parallelizer=parallelizer, ) push!(bmks, bmk) end diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 9c6f7f3f7f75000fc636ca093b58f2491336ec04..f3770aa4b87b859fd7a6d4885c983bc64a696e44 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -17,6 +17,12 @@ Base.@kwdef struct Experiment n_individuals::Int = 50 ce_measures::AbstractArray = CE_MEASURES model_measures::Dict = MODEL_MEASURES + use_class_loss::Bool = true + use_variants::Bool = true + Λ::AbstractArray = [0.25, 0.75, 0.75] + Λ_Δ::AbstractArray = Λ + opt::Flux.Optimise.Optimizer = Flux.Optimise.Descent(0.01) + parallelizer::Union{Nothing, AbstractParallelizer} = nothing end "A container to hold the results of an experiment." diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl index 1ee2e24c2b96fc1087129cd4f547a4ec01db7c7d..564bc50afbef609d7ed4ad292a5a23a7dc2f33de 100644 --- a/experiments/setup_env.jl +++ b/experiments/setup_env.jl @@ -11,6 +11,7 @@ using CounterfactualExplanations.Evaluation: benchmark, evaluate, Benchmark using CounterfactualExplanations.Generators: JSMADescent using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs using CounterfactualExplanations.Objectives +using CounterfactualExplanations.Parallelization using CSV using DataFrames using Distributions: Normal, Distribution, Categorical