From 5ffce35db03c9b1594649237881295b8cb8a38eb Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Tue, 5 Sep 2023 12:04:07 +0200 Subject: [PATCH] ready to do first full run --- experiments/mnist.jl | 3 ++- experiments/models/train_models.jl | 14 +------------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/experiments/mnist.jl b/experiments/mnist.jl index 5e1e9475..21cb7b79 100644 --- a/experiments/mnist.jl +++ b/experiments/mnist.jl @@ -35,7 +35,7 @@ run_experiment( ð’Ÿx = Uniform(-1.0, 1.0), α = [1.0,1.0,1e-2], sampling_batch_size = 10, - sampling_steps=25, + sampling_steps=50, use_ensembling = true, n_individuals = 5, nsamples = 10, @@ -43,4 +43,5 @@ run_experiment( use_variants = false, use_class_loss = true, add_models = add_models, + epochs = 10, ) \ No newline at end of file diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl index 5f3010b0..863584a4 100644 --- a/experiments/models/train_models.jl +++ b/experiments/models/train_models.jl @@ -4,19 +4,7 @@ Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects. """ function train_models(models::Dict, X, y; kwargs...) - if USE_THREADS - model_dicts = [Dict{Any,Any}() for i in 1:Threads.nthreads()] - mod_names = collect(keys(models)) - mod_values = collect(values(models)) - Threads.@threads for i in eachindex(mod_names) - mod_name = mod_names[i] - model = mod_values[i] - model_dicts[Threads.threadid()][mod_name] = _train(model, X, y; mod_name=mod_name, kwargs...) - end - model_dict = reduce(merge, model_dicts) - else - model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models) - end + model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models) return model_dict end -- GitLab