diff --git a/experiments/circles.jl b/experiments/circles.jl
index 1bfe1ad5fe13945d72a57558aed5da5c16cfa862..ca50439846d0b66437e7f172e2700e9529013a06 100644
--- a/experiments/circles.jl
+++ b/experiments/circles.jl
@@ -10,15 +10,15 @@ model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
 tuning_params = DEFAULT_GENERATOR_TUNING
 
 # Parameter choices:
+# These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM.
 params = (
+    use_tuned=false,
     n_hidden=32,
-    activation = Flux.swish,
-    sampling_steps=20,
-    opt=Flux.Optimise.Descent(0.01),
+    n_layers=3,
+    activation=Flux.swish,
+    epochs=100,
     α=[1.0, 1.0, 1e-2],
-    nsamples=100,
-    niter_eccco=100,
-    Λ=[0.1, 0.2, 0.2],
+    sampling_steps=30,
 )
 
 if !GRID_SEARCH
@@ -31,6 +31,7 @@ else
     grid_search(
         counterfactual_data, test_data;
         dataname=dataname,
-        tuning_params=tuning_params
+        tuning_params=tuning_params,
+        params...
     )
 end
\ No newline at end of file
diff --git a/experiments/experiment.jl b/experiments/experiment.jl
index 74755b7b9927e16ea2903fa38704efb9a5a4cb2f..3e4180801fbe2f1377b8f8d7df3eedebceb1e7cb 100644
--- a/experiments/experiment.jl
+++ b/experiments/experiment.jl
@@ -9,14 +9,15 @@ Base.@kwdef struct Experiment
     use_pretrained::Bool = !RETRAIN
     models::Union{Nothing,Dict} = nothing
     additional_models::Union{Nothing,Dict} = nothing
-    builder::Union{Nothing,MLJFlux.Builder} = nothing
     𝒟x::Distribution = Normal()
     sampling_batch_size::Int = 50
     sampling_steps::Int = 50
     min_batch_size::Int = 128
     epochs::Int = 100
     n_hidden::Int = 32
+    n_layers::Int = 3
     activation::Function = Flux.relu
+    builder::Union{Nothing,MLJFlux.Builder} = default_builder(n_hidden=n_hidden, n_layers=n_layers, activation=activation)
     α::AbstractArray = [1.0, 1.0, 1e-1]
     n_ens::Int = 5
     use_ensembling::Bool = true
@@ -39,6 +40,7 @@
     reg_strength::Real = 0.1
     niter_eccco::Union{Nothing,Int} = nothing
     model_tuning_params::NamedTuple = DEFAULT_MODEL_TUNING_SMALL
+    use_tuned::Bool = true
 end
 
 "A container to hold the results of an experiment."
diff --git a/experiments/linearly_separable.jl b/experiments/linearly_separable.jl
index 0c98beb025b63ff776651258cc41f831dee64cf3..cfb02a7807b4982dfdd14869dfae0fca22a93971 100644
--- a/experiments/linearly_separable.jl
+++ b/experiments/linearly_separable.jl
@@ -13,16 +13,20 @@ model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
 tuning_params = DEFAULT_GENERATOR_TUNING
 
 # Parameter choices:
+# These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM.
 params = (
-    nsamples=100,
-    niter_eccco=100,
-    Λ=[0.1, 0.2, 0.2],
+    use_tuned=false,
+    n_hidden=16,
+    n_layers=3,
+    activation=Flux.swish,
+    epochs=100,
 )
 
 if !GRID_SEARCH
     run_experiment(
         counterfactual_data, test_data;
         dataname=dataname,
+        model_tuning_params=model_tuning_params,
         params...
     )
 else
@@ -30,5 +34,6 @@ else
         counterfactual_data, test_data;
         dataname=dataname,
         tuning_params=tuning_params,
+        params...
     )
 end
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 597205dbf98edce2950ac33255d3119c41ebbfa8..d063ca77ad216469a5edd8289e5e1e2989ac424c 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -10,14 +10,14 @@ function prepare_models(exper::Experiment; save_models::Bool=true)
     # Training:
     if !exper.use_pretrained
         if isnothing(exper.builder)
-            if tuned_mlp_exists(exper)
+            if tuned_mlp_exists(exper) && exper.use_tuned
                 @info "Loading tuned model architecture."
                 # Load the best MLP:
                 best_mlp = Serialization.deserialize(joinpath(tuned_model_path(exper), "$(exper.save_name)_best_mlp.jls"))
                 builder = best_mlp.best_model.builder
             else
                 # Otherwise, use default MLP:
-                builder = default_builder()
+                builder = default_builder(n_hidden=exper.n_hidden, n_layers=exper.n_layers, activation=exper.activation)
             end
         else
             builder = exper.builder
diff --git a/experiments/moons.jl b/experiments/moons.jl
index ef700ac810c54e5ec93e9c2e8898d67aae3fd10c..6d7d285107a50de334495df5934536c97e367a24 100644
--- a/experiments/moons.jl
+++ b/experiments/moons.jl
@@ -10,17 +10,15 @@ model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
 tuning_params = DEFAULT_GENERATOR_TUNING
 
 # Parameter choices:
+# These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM.
 params = (
-    epochs=500,
+    use_tuned=false,
     n_hidden=32,
-    activation = Flux.relu,
+    n_layers=3,
+    activation=Flux.relu,
+    epochs=500,
     sampling_batch_size=10,
     sampling_steps=30,
-    opt=Flux.Optimise.Descent(0.05),
-    α=[1.0, 1.0, 1e-1],
-    nsamples=100,
-    niter_eccco=100,
-    Λ=[0.1, 0.2, 0.2],
 )
 
 if !GRID_SEARCH
@@ -33,6 +31,7 @@ else
    grid_search(
        counterfactual_data, test_data;
        dataname=dataname,
-       tuning_params=tuning_params
+       tuning_params=tuning_params,
+       params...
    )
 end
\ No newline at end of file
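
Note: the changes above assume a default_builder that accepts n_hidden, n_layers and activation keyword arguments. For context, here is a minimal sketch of a builder with that interface, written against MLJFlux.MLP; this is an illustrative assumption, not necessarily the repository's actual implementation of default_builder.

using Flux, MLJFlux

# Illustrative sketch only (assumed interface, see note above): map the new
# Experiment fields n_hidden, n_layers and activation onto an MLJFlux MLP builder.
function default_builder(; n_hidden::Int=32, n_layers::Int=3, activation::Function=Flux.relu)
    return MLJFlux.MLP(hidden=ntuple(_ -> n_hidden, n_layers), σ=activation)
end

# With the defaults above, the new Experiment field
#   builder = default_builder(n_hidden=n_hidden, n_layers=n_layers, activation=activation)
# would yield an MLP with three hidden layers of 32 units and relu activations.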