From d09efd11fa2d0c51cc20fe0eb854127d83ea7a23 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Wed, 13 Sep 2023 10:41:54 +0200 Subject: [PATCH] matame --- experiments/benchmarking/benchmarking.jl | 17 +++++++++-------- experiments/experiment.jl | 1 + experiments/gmsc.jl | 5 +++-- src/generator.jl | 7 +++++-- src/penalties.jl | 3 +-- 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/experiments/benchmarking/benchmarking.jl b/experiments/benchmarking/benchmarking.jl index 64f864b3..ddfacbce 100644 --- a/experiments/benchmarking/benchmarking.jl +++ b/experiments/benchmarking/benchmarking.jl @@ -4,6 +4,7 @@ function default_generators(; use_variants::Bool=true, use_class_loss::Bool=false, opt=Flux.Optimise.Descent(0.01), + niter_eccco::Union{Nothing,Int}=nothing, nsamples::Union{Nothing,Int}=nothing, nmin::Union{Nothing,Int}=nothing, reg_strength::Real=0.5, @@ -18,20 +19,20 @@ function default_generators(; "Wachter" => WachterGenerator(λ=λâ‚, opt=opt), "REVISE" => REVISEGenerator(λ=λâ‚, opt=opt), "Schut" => GreedyGenerator(), - "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), - "ECCCo (no CP)" => ECCCoGenerator(λ=[λâ‚, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), - "ECCCo (no EBM)" => ECCCoGenerator(λ=[λâ‚, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), - "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength), - "ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λâ‚_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength), - "ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λâ‚_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength), + "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), + "ECCCo (no CP)" => ECCCoGenerator(λ=[λâ‚, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), + "ECCCo (no EBM)" => ECCCoGenerator(λ=[λâ‚, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), + "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength = reg_strength), + "ECCCo-Δ (no CP)" => ECCCoGenerator(λ=[λâ‚_Δ, 0.0, λ₃_Δ], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength), + "ECCCo-Δ (no EBM)" => ECCCoGenerator(λ=[λâ‚_Δ, λ₂_Δ, 0.0], opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength), ) else generator_dict = Dict( "Wachter" => WachterGenerator(λ=λâ‚, opt=opt), "REVISE" => REVISEGenerator(λ=λâ‚, opt=opt), "Schut" => GreedyGenerator(), - "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin), - "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, reg_strength=reg_strength), + "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, nsamples=nsamples, nmin=nmin, niter=niter_eccco), + "ECCCo-Δ" => ECCCoGenerator(λ=Λ_Δ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true, nsamples=nsamples, nmin=nmin, niter=niter_eccco, reg_strength=reg_strength), ) end return generator_dict diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 46a81a1d..b161df03 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -37,6 +37,7 @@ Base.@kwdef struct Experiment loss::Function = Flux.Losses.crossentropy train_parallel::Bool = false reg_strength::Real = 0.1 + niter_eccco::Union{Nothing,Int} = nothing end "A container to hold the results of an experiment." diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index 31e6ced5..74ff1bb3 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -27,6 +27,7 @@ run_experiment( n_individuals = n_ind, use_variants = false, min_batch_size = 250, - nsamples = 10, - nmin = 10, + nsamples = 100, + nmin = 1, + niter_eccco = 30, ) \ No newline at end of file diff --git a/src/generator.jl b/src/generator.jl index 5a1117b8..b52fa275 100644 --- a/src/generator.jl +++ b/src/generator.jl @@ -11,12 +11,15 @@ function ECCCoGenerator(; use_energy_delta::Bool=false, nsamples::Union{Nothing,Int}=nothing, nmin::Union{Nothing,Int}=nothing, - reg_strength::Real=0.5, + niter::Union{Nothing,Int}=nothing, + reg_strength::Real=0.1, kwargs... ) + # Default ECCCo parameters nsamples = isnothing(nsamples) ? 50 : nsamples nmin = isnothing(nmin) ? 25 : nmin + niter = isnothing(niter) ? 500 : niter # Default optimiser if isnothing(opt) @@ -31,7 +34,7 @@ function ECCCoGenerator(; end _energy_penalty = - use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin)) + use_energy_delta ? (ECCCo.energy_delta, (n=nsamples, nmin=nmin, niter=niter, reg_strength=reg_strength)) : (ECCCo.distance_from_energy, (n=nsamples, nmin=nmin, niter=niter)) _penalties = [ (Objectives.distance_l1, []), diff --git a/src/penalties.jl b/src/penalties.jl index 7c57a067..9df579b5 100644 --- a/src/penalties.jl +++ b/src/penalties.jl @@ -78,8 +78,7 @@ function energy_delta( _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...) end eng_sampler = _dict[:energy_sampler] - generate_samples!(eng_sampler, ce.num_counterfactuals, get_target_index(ce.data.y_levels, ce.target); niter=niter) - xsampled = eng_sampler.buffer[:,(end-ce.num_counterfactuals+1):end] + xsampled = rand(eng_sampler, ce.num_counterfactuals; from_buffer=from_buffer) push!(conditional_samples, xsampled) end -- GitLab