Skip to content
Snippets Groups Projects
Commit d09efd11 authored by Pat Alt's avatar Pat Alt
Browse files

matame

parent 4dfb424e
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -4,6 +4,7 @@ function default_generators(;
use_variants::Bool=true,
use_class_loss::Bool=false,
opt=Flux.Optimise.Descent(0.01),
niter_eccco::Union{Nothing,Int}=nothing,
nsamples::Union{Nothing,Int}=nothing,
nmin::Union{Nothing,Int}=nothing,
reg_strength::Real=0.5,
......@@ -18,20 +19,20 @@ function default_generators(;
"Wachter" => WachterGenerator(λ=λ₁, opt=opt),
"REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
"Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
"ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λ₁_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
"ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λ₁_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
"ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
"ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
"ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength = reg_strength),
"ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λ₁_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength),
"ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λ₁_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength),
)
else
generator_dict = Dict(
"Wachter" => WachterGenerator(λ=λ₁, opt=opt),
"REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
"Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin),
"ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength),
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco),
"ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength),
)
end
return generator_dict
......
......@@ -37,6 +37,7 @@ Base.@kwdef struct Experiment
loss::Function = Flux.Losses.crossentropy
train_parallel::Bool = false
reg_strength::Real = 0.1
niter_eccco::Union{Nothing,Int} = nothing
end
"A container to hold the results of an experiment."
......
......@@ -27,6 +27,7 @@ run_experiment(
n_individuals = n_ind,
use_variants = false,
min_batch_size = 250,
nsamples = 10,
nmin = 10,
nsamples = 100,
nmin = 1,
niter_eccco = 30,
)
\ No newline at end of file
......@@ -11,12 +11,15 @@ function ECCCoGenerator(;
use_energy_delta::Bool=false,
nsamples::Union{Nothing,Int}=nothing,
nmin::Union{Nothing,Int}=nothing,
reg_strength::Real=0.5,
niter::Union{Nothing,Int}=nothing,
reg_strength::Real=0.1,
kwargs...
)
# Default ECCCo parameters
nsamples = isnothing(nsamples) ? 50 : nsamples
nmin = isnothing(nmin) ? 25 : nmin
niter = isnothing(niter) ? 500 : niter
# Default optimiser
if isnothing(opt)
......@@ -31,7 +34,7 @@ function ECCCoGenerator(;
end
_energy_penalty =
use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin))
use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, niter=niter, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin, niter=niter))
_penalties = [
(Objectives.distance_l1, []),
......
......@@ -78,8 +78,7 @@ function energy_delta(
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
end
eng_sampler = _dict[:energy_sampler]
generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter)
xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end]
xsampled = rand(eng_sampler, ce.num_counterfactuals; from_buffer=from_buffer)
push!(conditional_samples, xsampled)
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment