Commit 048e974e authored by Pat Alt

slowly slowly

parent 5ffce35d
1 merge request: !76 "69 initial run including fmnist lenet and new method"
GMSC experiment script:

 counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
 run_experiment(
     counterfactual_data, test_data;
     dataname="GMSC",
-    n_hidden=128,
-    activation = Flux.swish,
-    builder = MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     α=[1.0, 1.0, 1e-1],
     sampling_batch_size=10,
     sampling_steps = 30,
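Note on the pattern introduced above: MLJFlux.@builder captures the wrapped expression as a builder in which n_in and n_out remain free; MLJFlux binds them to the data's input and output dimensions at fit time, which is what lets one builder object be defined once and passed to run_experiment. A minimal, self-contained sketch follows; the direct build call and RNG are illustrative, not part of this commit.

using Flux, MLJFlux, Random

n_hidden = 128
activation = Flux.swish

# n_in and n_out are left unbound; MLJFlux supplies them when fitting.
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden, activation),
    Dense(n_hidden, n_hidden, activation),
    Dense(n_hidden, n_out),
)

# Materialize the network by hand, e.g. for 10 features and 2 classes:
chain = MLJFlux.build(builder, Random.default_rng(), 10, 2)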
MNIST experiment script:

@@ -22,16 +22,19 @@ add_models = Dict(
     :lenet5 => lenet5,
 )
+# Default builder:
+n_hidden = 128
+activation = Flux.swish
+builder = MLJFlux.@builder Flux.Chain(
+    Dense(n_in, n_hidden, activation),
+    Dense(n_hidden, n_out),
+)
+
 # Run:
 run_experiment(
     counterfactual_data, test_data;
     dataname="MNIST",
-    n_hidden = 128,
-    activation = Flux.swish,
-    builder = MLJFlux.@builder Flux.Chain(
-        Dense(n_in, n_hidden, activation),
-        Dense(n_hidden, n_out),
-    ),
+    builder = builder,
     𝒟x = Uniform(-1.0, 1.0),
     α = [1.0,1.0,1e-2],
     sampling_batch_size = 10,

@@ -42,6 +45,6 @@ run_experiment(
     nmin = 10,
     use_variants = false,
     use_class_loss = true,
-    add_models = add_models,
+    additional_models = add_models,
     epochs = 10,
 )
\ No newline at end of file
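For readers unfamiliar with the lenet5 entry in add_models: LeNet-5 is the classic convolutional architecture for 28×28 MNIST digits. Below is a hypothetical Flux sketch of such a network; the repository's actual lenet5 constructor may differ.

using Flux

# Illustrative LeNet-5-style network for 28×28×1 MNIST inputs.
lenet5_sketch = Flux.Chain(
    Conv((5, 5), 1 => 6, relu),   # 28×28×1 -> 24×24×6
    MaxPool((2, 2)),              # -> 12×12×6
    Conv((5, 5), 6 => 16, relu),  # -> 8×8×16
    MaxPool((2, 2)),              # -> 4×4×16
    Flux.flatten,                 # -> 256 features
    Dense(256, 120, relu),
    Dense(120, 84, relu),
    Dense(84, 10),                # one logit per digit class
)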
@@ -11,6 +11,8 @@ function prepare_models(exp::Experiment)
     if !exp.use_pretrained
         if isnothing(exp.builder)
             builder = default_builder()
+        else
+            builder = exp.builder
         end
         # Default models:
         if isnothing(exp.models)

@@ -25,6 +27,7 @@ function prepare_models(exp::Experiment)
                 use_ensembling=exp.use_ensembling,
                 finaliser=exp.finaliser,
                 loss=exp.loss,
+                epochs=exp.epochs,
             )
         end
         # Additional models:

@@ -36,10 +39,10 @@ function prepare_models(exp::Experiment)
                 batch_size=batch_size(exp),
                 finaliser=exp.finaliser,
                 loss=exp.loss,
+                epochs=exp.epochs,
             )
         end
-        add_models = Dict(k => mod(;batch_size=batch_size(exp), ) for (k, mod) in exp.additional_models)
-        models = merge(models, add_models)
+        models = merge!(models, exp.additional_models)
     end
     @info "Training models."
     model_dict = train_models(models, X, labels; cov=exp.coverage)
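Two behavioural notes on the hunk above. First, the added else branch means a user-supplied exp.builder is no longer silently ignored; the logic reduces to builder = isnothing(exp.builder) ? default_builder() : exp.builder. Second, merge! mutates models in place, with entries from exp.additional_models overwriting same-named defaults. A self-contained sketch of that merge semantics (placeholder values, not the repository's models):

models = Dict(:mlp => :default_mlp)
additional_models = Dict(:lenet5 => :lenet)
merge!(models, additional_models)
# models now holds both :mlp and :lenet5; on key clashes the
# right-hand Dict (additional_models) wins.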
@@ -12,7 +12,7 @@ using CounterfactualExplanations.Parallelization
 using CSV
 using Dates
 using DataFrames
-using Distributions: Normal, Distribution, Categorical
+using Distributions: Normal, Distribution, Categorical, Uniform
 using ECCCo
 using Flux
 using JointEnergyModels
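The widened import backs the new 𝒟x = Uniform(-1.0, 1.0) argument in the MNIST script above; without it, Uniform would be undefined in this module. A quick usage illustration:

using Distributions: Uniform

𝒟x = Uniform(-1.0, 1.0)  # sampling distribution over the input domain
x = rand(𝒟x, 3)          # e.g. three draws from [-1, 1]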