Commit 80d17217 authored by Pat Alt

going back to default models for synthetic

parent e3bfed0c
1 merge request: !7669 initial run including fmnist lenet and new method
@@ -10,15 +10,15 @@ model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
tuning_params = DEFAULT_GENERATOR_TUNING
# Parameter choices:
# These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM.
params = (
use_tuned=false,
n_hidden=32,
activation = Flux.swish,
sampling_steps=20,
opt=Flux.Optimise.Descent(0.01),
n_layers=3,
activation=Flux.swish,
epochs=100,
α=[1.0, 1.0, 1e-2],
nsamples=100,
niter_eccco=100,
Λ=[0.1, 0.2, 0.2],
sampling_steps=30,
)
if !GRID_SEARCH
@@ -31,6 +31,7 @@ else
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params
tuning_params=tuning_params,
params...
)
end
\ No newline at end of file
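The trailing comma added after tuning_params=tuning_params is needed because params... now follows on the next line: arguments in a Julia call must be comma-separated even across line breaks. A minimal sketch of the keyword-splatting pattern, using a hypothetical function f that is not part of the repository:

# Hypothetical function: collects a named keyword plus any splatted extras.
f(; dataname="demo", kwargs...) = (dataname, values(kwargs))
params = (n_hidden=32, epochs=100)
f(; dataname="linearly_separable", params...)  # ("linearly_separable", (n_hidden = 32, epochs = 100))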
@@ -9,14 +9,15 @@ Base.@kwdef struct Experiment
use_pretrained::Bool = !RETRAIN
models::Union{Nothing,Dict} = nothing
additional_models::Union{Nothing,Dict} = nothing
builder::Union{Nothing,MLJFlux.Builder} = nothing
𝒟x::Distribution = Normal()
sampling_batch_size::Int = 50
sampling_steps::Int = 50
min_batch_size::Int = 128
epochs::Int = 100
n_hidden::Int = 32
n_layers::Int = 3
activation::Function = Flux.relu
builder::Union{Nothing,MLJFlux.Builder} = default_builder(n_hidden=n_hidden, n_layers=n_layers, activation=activation)
α::AbstractArray = [1.0, 1.0, 1e-1]
n_ens::Int = 5
use_ensembling::Bool = true
@@ -39,6 +40,7 @@ Base.@kwdef struct Experiment
reg_strength::Real = 0.1
niter_eccco::Union{Nothing,Int} = nothing
model_tuning_params::NamedTuple = DEFAULT_MODEL_TUNING_SMALL
use_tuned::Bool = true
end
"A container to hold the results of an experiment."
......
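Note on the reordered builder field: Base.@kwdef lets a field's default value refer to fields declared above it, which is why builder now appears after n_hidden, n_layers and activation. A toy illustration of that behaviour (not the repository's actual types):

Base.@kwdef struct Toy
    n_hidden::Int = 32
    n_layers::Int = 3
    total_units::Int = n_hidden * n_layers  # a default may reference earlier fields
end

Toy(n_hidden=16).total_units  # 48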
@@ -13,16 +13,20 @@ model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
tuning_params = DEFAULT_GENERATOR_TUNING
# Parameter choices:
# These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM.
params = (
nsamples=100,
niter_eccco=100,
Λ=[0.1, 0.2, 0.2],
use_tuned=false,
n_hidden=16,
n_layers=3,
activation=Flux.swish,
epochs=100,
)
if !GRID_SEARCH
run_experiment(
counterfactual_data, test_data;
dataname=dataname,
model_tuning_params=model_tuning_params,
params...
)
else
@@ -30,5 +34,6 @@ else
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params,
params...
)
end
\ No newline at end of file
@@ -10,14 +10,14 @@ function prepare_models(exper::Experiment; save_models::Bool=true)
# Training:
if !exper.use_pretrained
if isnothing(exper.builder)
if tuned_mlp_exists(exper)
if tuned_mlp_exists(exper) && exper.use_tuned
@info "Loading tuned model architecture."
# Load the best MLP:
best_mlp = Serialization.deserialize(joinpath(tuned_model_path(exper), "$(exper.save_name)_best_mlp.jls"))
builder = best_mlp.best_model.builder
else
# Otherwise, use default MLP:
builder = default_builder()
builder = default_builder(n_hidden=exper.n_hidden, n_layers=exper.n_layers, activation=exper.activation)
end
else
builder = exper.builder
......
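For context, one plausible shape of a default_builder with this keyword signature, sketched here with MLJFlux's built-in MLP builder; the repository's actual helper may be implemented differently:

using Flux, MLJFlux

function default_builder(; n_hidden::Int=32, n_layers::Int=3, activation::Function=Flux.relu)
    # n_layers hidden layers with n_hidden units each, all using the given activation
    return MLJFlux.MLP(hidden=ntuple(_ -> n_hidden, n_layers), σ=activation)
end

default_builder(n_hidden=16, n_layers=3, activation=Flux.swish)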
@@ -10,17 +10,15 @@ model_tuning_params = DEFAULT_MODEL_TUNING_SMALL
tuning_params = DEFAULT_GENERATOR_TUNING
# Parameter choices:
# These are the parameter choices originally used in the paper that were manually fine-tuned for the JEM.
params = (
epochs=500,
use_tuned=false,
n_hidden=32,
activation = Flux.relu,
n_layers=3,
activation=Flux.relu,
epochs=500,
sampling_batch_size=10,
sampling_steps=30,
opt=Flux.Optimise.Descent(0.05),
α=[1.0, 1.0, 1e-1],
nsamples=100,
niter_eccco=100,
Λ=[0.1, 0.2, 0.2],
)
if !GRID_SEARCH
@@ -33,6 +31,7 @@ else
grid_search(
counterfactual_data, test_data;
dataname=dataname,
tuning_params=tuning_params
tuning_params=tuning_params,
params...
)
end
\ No newline at end of file
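For reference, Flux.Optimise.Descent(0.05) is plain gradient descent with step size 0.05. A minimal, self-contained illustration of what that optimiser does to a parameter array, nothing repository-specific:

using Flux

w = [1.0, -2.0]                    # parameters
g = [0.2, 0.4]                     # gradient
opt = Flux.Optimise.Descent(0.05)
Flux.Optimise.update!(opt, w, g)   # w .-= 0.05 .* g, so w is now [0.99, -2.02]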