From 5ffce35db03c9b1594649237881295b8cb8a38eb Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 5 Sep 2023 12:04:07 +0200
Subject: [PATCH] ready to do first full run

---
 experiments/mnist.jl               |  3 ++-
 experiments/models/train_models.jl | 14 +-------------
 2 files changed, 3 insertions(+), 14 deletions(-)

diff --git a/experiments/mnist.jl b/experiments/mnist.jl
index 5e1e9475..21cb7b79 100644
--- a/experiments/mnist.jl
+++ b/experiments/mnist.jl
@@ -35,7 +35,7 @@ run_experiment(
     𝒟x = Uniform(-1.0, 1.0),
     α = [1.0,1.0,1e-2],
     sampling_batch_size = 10,
-    sampling_steps=25,
+    sampling_steps=50,
     use_ensembling = true,
     n_individuals = 5,
     nsamples = 10,
@@ -43,4 +43,5 @@ run_experiment(
     use_variants = false,
     use_class_loss = true,
     add_models = add_models,
+    epochs = 10,
 )
\ No newline at end of file
diff --git a/experiments/models/train_models.jl b/experiments/models/train_models.jl
index 5f3010b0..863584a4 100644
--- a/experiments/models/train_models.jl
+++ b/experiments/models/train_models.jl
@@ -4,19 +4,7 @@
 Trains all models in a dictionary and returns a dictionary of `ConformalModel` objects.
 """
 function train_models(models::Dict, X, y; kwargs...)
-    if USE_THREADS
-        model_dicts = [Dict{Any,Any}() for i in 1:Threads.nthreads()] 
-        mod_names = collect(keys(models))
-        mod_values = collect(values(models))
-        Threads.@threads for i in eachindex(mod_names)
-            mod_name = mod_names[i]
-            model = mod_values[i]
-            model_dicts[Threads.threadid()][mod_name] = _train(model, X, y; mod_name=mod_name, kwargs...)
-        end
-        model_dict = reduce(merge, model_dicts)
-    else
-        model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models)
-    end
+    model_dict = Dict(mod_name => _train(model, X, y; mod_name=mod_name, kwargs...) for (mod_name, model) in models)
     return model_dict
 end
 
-- 
GitLab