Skip to content
Snippets Groups Projects
Commit 048e974e authored by Pat Alt's avatar Pat Alt
Browse files

slowly slowly

parent 5ffce35d
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=TEST_SIZE)
# Default builder:
n_hidden = 128
activation = Flux.swish
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
Dense(n_hidden, n_hidden, activation),
Dense(n_hidden, n_out),
)
run_experiment(
counterfactual_data, test_data;
dataname="GMSC",
n_hidden=128,
activation = Flux.swish,
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
Dense(n_hidden, n_hidden, activation),
Dense(n_hidden, n_out),
),
builder = builder,
α=[1.0, 1.0, 1e-1],
sampling_batch_size=10,
sampling_steps = 30,
......
......@@ -22,16 +22,19 @@ add_models = Dict(
:lenet5 => lenet5,
)
# Default builder:
n_hidden = 128
activation = Flux.swish
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
Dense(n_hidden, n_out),
)
# Run:
run_experiment(
counterfactual_data, test_data;
dataname="MNIST",
n_hidden = 128,
activation = Flux.swish,
builder= MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
Dense(n_hidden, n_out),
),
builder= builder,
𝒟x = Uniform(-1.0, 1.0),
α = [1.0,1.0,1e-2],
sampling_batch_size = 10,
......@@ -42,6 +45,6 @@ run_experiment(
nmin = 10,
use_variants = false,
use_class_loss = true,
add_models = add_models,
additional_models=add_models,
epochs = 10,
)
\ No newline at end of file
......@@ -11,6 +11,8 @@ function prepare_models(exp::Experiment)
if !exp.use_pretrained
if isnothing(exp.builder)
builder = default_builder()
else
builder = exp.builder
end
# Default models:
if isnothing(exp.models)
......@@ -25,6 +27,7 @@ function prepare_models(exp::Experiment)
use_ensembling=exp.use_ensembling,
finaliser=exp.finaliser,
loss=exp.loss,
epochs=exp.epochs,
)
end
# Additional models:
......@@ -36,10 +39,10 @@ function prepare_models(exp::Experiment)
batch_size=batch_size(exp),
finaliser=exp.finaliser,
loss=exp.loss,
epochs=exp.epochs,
)
end
add_models = Dict(k => mod(;batch_size=batch_size(exp), ) for (k, mod) in exp.additional_models)
models = merge!(models, exp.additional_models)
models = merge(models, add_models)
end
@info "Training models."
model_dict = train_models(models, X, labels; cov=exp.coverage)
......
......@@ -12,7 +12,7 @@ using CounterfactualExplanations.Parallelization
using CSV
using Dates
using DataFrames
using Distributions: Normal, Distribution, Categorical
using Distributions: Normal, Distribution, Categorical, Uniform
using ECCCo
using Flux
using JointEnergyModels
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment