diff --git a/experiments/circles.jl b/experiments/circles.jl index 6cb3a38a79c06280507aab8d1c82dc88b9195d62..110c22f7b57ce8048ebd5b52e945b250bfcf76d1 100644 --- a/experiments/circles.jl +++ b/experiments/circles.jl @@ -1,14 +1,12 @@ n_obs = Int(1000 / (1.0 - TEST_SIZE)) counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=TEST_SIZE) run_experiment( - counterfactual_data, test_data; dataname="Circles", + counterfactual_data, test_data; + dataname="Circles", n_hidden=32, α=[1.0, 1.0, 1e-2], sampling_batch_size=nothing, sampling_steps=20, - λâ‚=0.25, - λ₂ = 0.75, - λ₃ = 0.75, + Λ=[0.25, 0.75, 0.75], opt=Flux.Optimise.Descent(0.01), - use_class_loss = false, ) \ No newline at end of file diff --git a/experiments/experiment.jl b/experiments/experiment.jl index 60df51485aeb706c1361949de27c91492bd0c94c..469a123ee9f69e20f403476d629f1e1647bb8df2 100644 --- a/experiments/experiment.jl +++ b/experiments/experiment.jl @@ -11,13 +11,20 @@ Base.@kwdef struct Experiment builder::Union{Nothing,MLJFlux.Builder} = nothing ð’Ÿx::Distribution = Normal() sampling_batch_size::Int = 50 + sampling_steps::Int = 50 min_batch_size::Int = 128 + epochs::Int = 100 + n_hidden::Int = 32 + activation::Function = Flux.relu + α::AbstractArray = [1.0, 1.0, 1e-1] + n_ens::Int = 5 + use_ensembling::Bool = true coverage::Float64 = DEFAULT_COVERAGE generators::Union{Nothing,Dict} = nothing n_individuals::Int = 50 ce_measures::AbstractArray = CE_MEASURES model_measures::Dict = MODEL_MEASURES - use_class_loss::Bool = true + use_class_loss::Bool = false use_variants::Bool = true Λ::AbstractArray = [0.25, 0.75, 0.75] Λ_Δ::AbstractArray = Λ diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl index 8b53e93e2ed80a6a1c0630249445c5b60a496f6d..132e02f34035df6840f821315485ce9af2641634 100644 --- a/experiments/gmsc.jl +++ b/experiments/gmsc.jl @@ -12,10 +12,6 @@ run_experiment( sampling_batch_size=nothing, sampling_steps = 30, use_ensembling = true, - λ₠= 0.1, - λ₂ = 0.5, - λ₃ = 0.5, + Λ=[0.1, 0.5, 0.5], opt = Flux.Optimise.Descent(0.05), - use_class_loss=false, - use_variants=false, ) \ No newline at end of file diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl index 4adc399afd84011de4c640b691c202a29887203a..4fcf9ea780394504dcf4314417c14deb311e8c0b 100644 --- a/experiments/linearly_separable.jl +++ b/experiments/linearly_separable.jl @@ -3,4 +3,7 @@ counterfactual_data, test_data = train_test_split( load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0)); test_size=TEST_SIZE ) -run_experiment(counterfactual_data, test_data; dataname="Linearly Separable") \ No newline at end of file +run_experiment( + counterfactual_data, test_data; + dataname="Linearly Separable" +) \ No newline at end of file diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 39e13d406a24c711eeb057292901d562119e4bd9..557143e98e5f707a8d495bb6cffdeef9593a19ee 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -52,7 +52,7 @@ run_experiment( ð’Ÿx = Uniform(-1.0, 1.0), α = [1.0,1.0,1e-2], sampling_batch_size = 10, - ssampling_steps=25, + sampling_steps=25, use_ensembling = true, generators = generator_dict, ) \ No newline at end of file diff --git a/experiments/models/models.jl b/experiments/models/models.jl index 16bffaeb31583394abea0deabc9d24cf07523a91..472e1c16af56fd3dfbd42bdd5fcde365387ac6d7 100644 --- a/experiments/models/models.jl +++ b/experiments/models/models.jl @@ -17,7 +17,8 @@ function prepare_models(exp::Experiment) models = default_models(; sampler=sampler, builder=builder, - batch_size=batch_size(exp) + batch_size=batch_size(exp), + sampling_steps=exp.sampling_steps, ) end @info "Training models." diff --git a/experiments/moons.jl b/experiments/moons.jl index bb660842de036f330999580e7d3a0aa384a81be8..0203438024ecca4b7745bc3125e0407f442fa01c 100644 --- a/experiments/moons.jl +++ b/experiments/moons.jl @@ -1,16 +1,13 @@ n_obs = Int(2500 / (1.0 - TEST_SIZE)) counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=TEST_SIZE) run_experiment( - counterfactual_data, test_data; dataname="Moons", + counterfactual_data, test_data; + dataname="Moons", epochs=500, n_hidden=32, activation = Flux.relu, - α=[1.0, 1.0, 1e-1], sampling_batch_size=10, sampling_steps=30, - λâ‚=0.25, - λ₂=0.75, - λ₃=0.75, + Λ=[0.25, 0.75, 0.75], opt=Flux.Optimise.Descent(0.05), - use_class_loss=false ) \ No newline at end of file