Skip to content
Snippets Groups Projects
Commit 804c8442 authored by Pat Alt's avatar Pat Alt
Browse files

added parallelization option

parent cd63eee5
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
function default_generators( function default_generators(;
Λ::AbstractArray=[0.25, 0.75, 0.75], Λ::AbstractArray=[0.25, 0.75, 0.75],
Λ_Δ::AbstractArray=Λ, Λ_Δ::AbstractArray=Λ,
use_variants::Bool=true, use_variants::Bool=true,
...@@ -49,10 +49,17 @@ function run_benchmark(exp::Experiment, model_dict::Dict) ...@@ -49,10 +49,17 @@ function run_benchmark(exp::Experiment, model_dict::Dict)
counterfactual_data = exp.counterfactual_data counterfactual_data = exp.counterfactual_data
generator_dict = exp.generators generator_dict = exp.generators
measures = exp.ce_measures measures = exp.ce_measures
parallelizer = exp.parallelizer
# Benchmark generators: # Benchmark generators:
if isnothing(generator_dict) if isnothing(generator_dict)
generator_dict = default_generators() generator_dict = default_generators(;
Λ=exp.Λ,
Λ_Δ=exp.Λ_Δ,
use_variants=exp.use_variants,
use_class_loss=exp.use_class_loss,
opt=exp.opt
)
end end
# Run benchmark: # Run benchmark:
...@@ -72,7 +79,8 @@ function run_benchmark(exp::Experiment, model_dict::Dict) ...@@ -72,7 +79,8 @@ function run_benchmark(exp::Experiment, model_dict::Dict)
n_individuals=n_individuals, n_individuals=n_individuals,
target=target, factual=factual, target=target, factual=factual,
initialization=:identity, initialization=:identity,
converge_when=:generator_conditions converge_when=:generator_conditions,
parallelizer=parallelizer,
) )
push!(bmks, bmk) push!(bmks, bmk)
end end
......
...@@ -17,6 +17,12 @@ Base.@kwdef struct Experiment ...@@ -17,6 +17,12 @@ Base.@kwdef struct Experiment
n_individuals::Int = 50 n_individuals::Int = 50
ce_measures::AbstractArray = CE_MEASURES ce_measures::AbstractArray = CE_MEASURES
model_measures::Dict = MODEL_MEASURES model_measures::Dict = MODEL_MEASURES
use_class_loss::Bool = true
use_variants::Bool = true
Λ::AbstractArray = [0.25, 0.75, 0.75]
Λ_Δ::AbstractArray = Λ
opt::Flux.Optimise.Optimizer = Flux.Optimise.Descent(0.01)
parallelizer::Union{Nothing, AbstractParallelizer} = nothing
end end
"A container to hold the results of an experiment." "A container to hold the results of an experiment."
......
...@@ -11,6 +11,7 @@ using CounterfactualExplanations.Evaluation: benchmark, evaluate, Benchmark ...@@ -11,6 +11,7 @@ using CounterfactualExplanations.Evaluation: benchmark, evaluate, Benchmark
using CounterfactualExplanations.Generators: JSMADescent using CounterfactualExplanations.Generators: JSMADescent
using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs using CounterfactualExplanations.Models: load_mnist_mlp, load_fashion_mnist_mlp, train, probs
using CounterfactualExplanations.Objectives using CounterfactualExplanations.Objectives
using CounterfactualExplanations.Parallelization
using CSV using CSV
using DataFrames using DataFrames
using Distributions: Normal, Distribution, Categorical using Distributions: Normal, Distribution, Categorical
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment