Skip to content
Snippets Groups Projects
Commit 726906e0 authored by Pat Alt's avatar Pat Alt
Browse files

reg strength

parent 038dda98
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
...@@ -6,6 +6,7 @@ function default_generators(; ...@@ -6,6 +6,7 @@ function default_generators(;
opt=Flux.Optimise.Descent(0.01), opt=Flux.Optimise.Descent(0.01),
nsamples::Union{Nothing,Int}=nothing, nsamples::Union{Nothing,Int}=nothing,
nmin::Union{Nothing,Int}=nothing, nmin::Union{Nothing,Int}=nothing,
reg_strength::Real=0.5,
) )
@info "Begin benchmarking counterfactual explanations." @info "Begin benchmarking counterfactual explanations."
...@@ -20,9 +21,9 @@ function default_generators(; ...@@ -20,9 +21,9 @@ function default_generators(;
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), "ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), "ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin), "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
"ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λ₁_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin), "ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λ₁_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
"ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λ₁_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin), "ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λ₁_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
) )
else else
generator_dict = Dict( generator_dict = Dict(
...@@ -30,7 +31,7 @@ function default_generators(; ...@@ -30,7 +31,7 @@ function default_generators(;
"REVISE" => REVISEGenerator(λ=λ₁, opt=opt), "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
"Schut" => GreedyGenerator(), "Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin), "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
) )
end end
return generator_dict return generator_dict
...@@ -63,6 +64,7 @@ function run_benchmark(exper::Experiment, model_dict::Dict) ...@@ -63,6 +64,7 @@ function run_benchmark(exper::Experiment, model_dict::Dict)
opt=exper.opt, opt=exper.opt,
nsamples=exper.nsamples, nsamples=exper.nsamples,
nmin=exper.nmin, nmin=exper.nmin,
reg_strength=exper.reg_strength,
) )
end end
......
counterfactual_data, test_data = train_test_split(load_credit_default(nothing); test_size=TEST_SIZE)
# Default builder:
n_hidden = 128
activation = Flux.swish
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
Dense(n_hidden, n_hidden, activation),
Dense(n_hidden, n_out),
)
# Number of individuals:
n_ind = N_IND_SPECIFIED ? N_IND : 10
run_experiment(
counterfactual_data, test_data;
dataname="Credit Default",
builder=builder,
α=[1.0, 1.0, 1e-1],
sampling_batch_size=10,
sampling_steps=30,
use_ensembling=true,
Λ=[0.1, 0.5, 0.5],
opt=Flux.Optimise.Descent(0.05),
n_individuals=n_ind,
use_variants=false
)
\ No newline at end of file
...@@ -36,6 +36,7 @@ Base.@kwdef struct Experiment ...@@ -36,6 +36,7 @@ Base.@kwdef struct Experiment
finaliser::Function = Flux.softmax finaliser::Function = Flux.softmax
loss::Function = Flux.Losses.crossentropy loss::Function = Flux.Losses.crossentropy
train_parallel::Bool = false train_parallel::Bool = false
reg_strength::Real = 0.5
end end
"A container to hold the results of an experiment." "A container to hold the results of an experiment."
......
...@@ -11,4 +11,4 @@ ...@@ -11,4 +11,4 @@
module load 2023r1 openmpi module load 2023r1 openmpi
srun julia --project=experiments experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results retrain threaded mpi > experiments/synthetic.log srun julia --project=experiments --threads 4 experiments/run_experiments.jl -- data=linearly_separable,moons,circles output_path=results retrain threaded mpi > experiments/synthetic.log
...@@ -11,6 +11,7 @@ function ECCCoGenerator(; ...@@ -11,6 +11,7 @@ function ECCCoGenerator(;
use_energy_delta::Bool=false, use_energy_delta::Bool=false,
nsamples::Union{Nothing,Int}=nothing, nsamples::Union{Nothing,Int}=nothing,
nmin::Union{Nothing,Int}=nothing, nmin::Union{Nothing,Int}=nothing,
reg_strength::Real=0.5,
kwargs... kwargs...
) )
...@@ -30,7 +31,7 @@ function ECCCoGenerator(; ...@@ -30,7 +31,7 @@ function ECCCoGenerator(;
end end
_energy_penalty = _energy_penalty =
use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin)) use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin))
_penalties = [ _penalties = [
(Objectives.distance_l1, []), (Objectives.distance_l1, []),
......
...@@ -46,10 +46,10 @@ function energy_delta( ...@@ -46,10 +46,10 @@ function energy_delta(
choose_random=false, choose_random=false,
nmin::Int=25, nmin::Int=25,
return_conditionals=false, return_conditionals=false,
reg_strength=0.5,
kwargs... kwargs...
) )
_loss = 0.0
nmin = minimum([nmin, n]) nmin = minimum([nmin, n])
@assert choose_lowest_energy choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither." @assert choose_lowest_energy choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
...@@ -88,7 +88,7 @@ function energy_delta( ...@@ -88,7 +88,7 @@ function energy_delta(
if return_conditionals if return_conditionals
return conditional_samples[1] return conditional_samples[1]
end end
return gen_loss + 0.1reg_loss return gen_loss + reg_strength * reg_loss
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment