diff --git a/experiments/gmsc.jl b/experiments/gmsc.jl
index f8ecaa4d60d29b3c1ed5511e6956598094ea9090..f9c8120ef8f4d74dcd85301b3e23c4143eeda81d 100644
--- a/experiments/gmsc.jl
+++ b/experiments/gmsc.jl
@@ -1,14 +1,18 @@
 counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
+
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
 run_experiment(
     counterfactual_data, test_data;
     dataname="GMSC",
-    n_hidden=128,
-    activation = Flux.swish,
-    builder = MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     α=[1.0, 1.0, 1e-1],
     sampling_batch_size=10,
     sampling_steps = 30,
diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 21cb7b7924ccabb210dcafd3eb6dde6d9bdb9761..372552e28fb5ab9cb94427177b08eace3e73aefa 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -22,16 +22,19 @@ add_models = Dict(
     :lenet5 => lenet5,
 )
 
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
 # Run:
 run_experiment(
     counterfactual_data, test_data;
     dataname="MNIST",
-    n_hidden = 128,
-    activation = Flux.swish,
-    builder= MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     𝒟x = Uniform(-1.0, 1.0),
     α = [1.0,1.0,1e-2],
     sampling_batch_size = 10,
@@ -42,6 +45,6 @@ run_experiment(
     nmin = 10,
     use_variants = false,
     use_class_loss = true,
-    add_models = add_models,
+    additional_models=add_models,
     epochs = 10,
 )
\ No newline at end of file
diff --git a/experiments/models/models.jl b/experiments/models/models.jl
index 89975bcd1c6800b606f9e069e80061af5f5ff96b..bd2d5026cfb1dc1173b8091e3638dc89ab7f010b 100644
--- a/experiments/models/models.jl
+++ b/experiments/models/models.jl
@@ -11,6 +11,8 @@ function prepare_models(exp::Experiment)
     if !exp.use_pretrained
         if isnothing(exp.builder)
             builder = default_builder()
+        else
+            builder = exp.builder
         end
         # Default models:
         if isnothing(exp.models)
@@ -25,6 +27,7 @@
                 use_ensembling=exp.use_ensembling,
                 finaliser=exp.finaliser,
                 loss=exp.loss,
+                epochs=exp.epochs,
             )
         end
         # Additional models:
@@ -36,10 +39,10 @@
                 batch_size=batch_size(exp),
                 finaliser=exp.finaliser,
                 loss=exp.loss,
+                epochs=exp.epochs,
             )
         end
-        add_models = Dict(k => mod(;batch_size=batch_size(exp), ) for (k, mod) in exp.additional_models)
-        models = merge!(models, exp.additional_models)
+        models = merge(models, add_models)
     end
     @info "Training models."
     model_dict = train_models(models, X, labels; cov=exp.coverage)
diff --git a/experiments/setup_env.jl b/experiments/setup_env.jl
index 1d6a3be32e3c5ceaa6074ab12b8c004712488bbb..c6657bdd40e2f0e0280b901cf6b3a91d662e93dc 100644
--- a/experiments/setup_env.jl
+++ b/experiments/setup_env.jl
@@ -12,7 +12,7 @@ using CounterfactualExplanations.Parallelization
 using CSV
 using Dates
 using DataFrames
-using Distributions: Normal, Distribution, Categorical
+using Distributions: Normal, Distribution, Categorical, Uniform
 using ECCCo
 using Flux
 using JointEnergyModels
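
Two of the models.jl hunks above fix real bugs in prepare_models rather than just renaming things: previously a caller-supplied `exp.builder` was never used, since `builder` was assigned only in the `isnothing(exp.builder)` branch, and the additional models were merged as raw constructors via `merge!(models, exp.additional_models)` instead of the configured instances held in `add_models`. Below is a minimal sketch of the corrected additional-models path, reconstructed from the context lines in the diff; the exact surrounding control flow is not shown in the hunks and is assumed here:

    # Instantiate each additional model with the shared training settings,
    # then merge the configured instances (not the raw constructors) into
    # the default models; `merge` returns a new Dict instead of mutating.
    add_models = Dict(
        k => mod(;
            batch_size=batch_size(exp),
            finaliser=exp.finaliser,
            loss=exp.loss,
            epochs=exp.epochs,
        ) for (k, mod) in exp.additional_models
    )
    models = merge(models, add_models)

On the caller side, the experiment scripts now construct the builder once and pass it in explicitly, and the keyword argument is `additional_models` rather than `add_models`, as in the experiments/mnist.jl hunk above.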