Skip to content
Snippets Groups Projects
Commit fdcb3290 authored by Pat Alt's avatar Pat Alt
Browse files

trying this a different way

parent 9b2df179
No related branches found
No related tags found
1 merge request!7669 initial run including fmnist lenet and new method
......@@ -70,33 +70,17 @@ function run_benchmark(exper::Experiment, model_dict::Dict)
end
# Run benchmark:
bmks = []
labels = counterfactual_data.output_encoder.labels
for target in sort(unique(labels))
for factual in sort(unique(labels))
if factual == target
continue
end
@info "Benchmarking factual=$(factual) ▶️ target=$(target)."
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname=dataname,
n_individuals=n_individuals,
target=target, factual=factual,
initialization=:identity,
converge_when=:generator_conditions,
parallelizer=parallelizer,
)
if is_multi_processed(exper)
MPI.Barrier(parallelizer.comm)
end
push!(bmks, bmk)
end
end
final_bmk = reduce(vcat, bmks)
return final_bmk, generator_dict
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname=dataname,
n_individuals=n_individuals,
initialization=:identity,
converge_when=:generator_conditions,
parallelizer=parallelizer
)
return bmk, generator_dict
end
......@@ -25,7 +25,7 @@ builder = MLJFlux.@builder Flux.Chain(
)
# Number of individuals:
n_ind = N_IND_SPECIFIED ? N_IND : 5
n_ind = N_IND_SPECIFIED ? N_IND : 10
# Run:
run_experiment(
......@@ -38,10 +38,11 @@ run_experiment(
sampling_steps=50,
use_ensembling = true,
n_individuals = n_ind,
nsamples = 10,
nmin = 10,
use_variants = false,
use_class_loss = true,
additional_models=add_models,
epochs = 10,
nsamples=10,
nmin=1,
niter_eccco=100
)
\ No newline at end of file
......@@ -11,4 +11,4 @@
module load 2023r1 openmpi
srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi retrain > experiments/mnist.log
srun julia --project=experiments experiments/run_experiments.jl -- data=mnist output_path=results mpi > experiments/mnist.log
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment