Skip to content
Snippets Groups Projects
Commit d8d746a4 authored by Pat Alt's avatar Pat Alt
Browse files

n_individuals

parent 7901661b
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -22,7 +22,7 @@ Base.@kwdef struct Experiment
use_ensembling::Bool = true
coverage::Float64 = DEFAULT_COVERAGE
generators::Union{Nothing,Dict} = nothing
n_individuals::Int = 25
n_individuals::Int = N_IND
ce_measures::AbstractArray = CE_MEASURES
model_measures::Dict = MODEL_MEASURES
use_class_loss::Bool = false
......
......@@ -24,6 +24,9 @@ builder = MLJFlux.@builder Flux.Chain(
Dense(n_hidden, n_out),
)
# Number of individuals:
n_ind = N_IND_SPECIFIED ? N_IND : 5
# Run:
run_experiment(
counterfactual_data, test_data;
......@@ -34,7 +37,7 @@ run_experiment(
sampling_batch_size = 10,
sampling_steps=50,
use_ensembling = true,
n_individuals = 5,
n_individuals = n_ind,
nsamples = 10,
nmin = 10,
use_variants = false,
......
......@@ -9,6 +9,9 @@ builder = MLJFlux.@builder Flux.Chain(
Dense(n_hidden, n_out),
)
# Number of individuals:
n_ind = N_IND_SPECIFIED ? N_IND : 10
run_experiment(
counterfactual_data, test_data;
dataname="GMSC",
......@@ -19,6 +22,6 @@ run_experiment(
use_ensembling = true,
Λ=[0.1, 0.5, 0.5],
opt = Flux.Optimise.Descent(0.05),
n_individuals = 10,
n_individuals = n_ind,
use_variants = false,
)
\ No newline at end of file
......@@ -24,6 +24,9 @@ builder = MLJFlux.@builder Flux.Chain(
Dense(n_hidden, n_out),
)
# Number of individuals:
n_ind = N_IND_SPECIFIED ? N_IND : 5
# Run:
run_experiment(
counterfactual_data, test_data;
......@@ -34,7 +37,7 @@ run_experiment(
sampling_batch_size = 10,
sampling_steps=50,
use_ensembling = true,
n_individuals = 5,
n_individuals = n_ind,
nsamples = 10,
nmin = 10,
use_variants = false,
......
......@@ -117,4 +117,18 @@ const CE_MEASURES = [
"Test set proportion."
const TEST_SIZE = 0.2
const UPLOAD = "upload" ARGS
\ No newline at end of file
const UPLOAD = "upload" ARGS
n_ind_specified = false
if any(contains.(ARGS, "n_individuals="))
n_ind_specified = true
n_individuals = ARGS[findall(contains.(ARGS, "n_individuals="))][1] |> x -> replace(x, "n_individuals=" => "")
else
n_individuals = 25
end
"Number of individuals to use in benchmarking."
const N_IND = n_individuals
"Boolean flag to check if number of individuals was specified."
const N_IND_SPECIFIED = n_ind_specified
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment