From 87020517e326417f88a93692c266b53136fd3505 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Fri, 4 Aug 2023 15:27:59 +0200
Subject: [PATCH] additional models

---
 notebooks/mnist.qmd | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index b39cee0f..32f19287 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -433,6 +433,35 @@ display(plt)
 savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
 ```
 
+#### Additional Models (not in paper)
+
+```{julia}
+add_retrain = true
+
+# Deep Ensemble:
+mlp_large_ens = EnsembleModel(model=mlp, n=50)
+
+add_models = Dict(
+    "Large MLP Ensemble" => mlp_large_ens,
+)
+
+if add_retrain
+    add_model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in add_models)
+    large_model_dict = merge(model_dict, add_model_dict)
+    Serialization.serialize(joinpath(output_path,"mnist_models_large.jls"), large_model_dict)
+else
+    large_model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models_large.jls"))
+end
+```
+
+```{julia}
+plt_additional_models, _, _ces_ = _plot_eccco_mnist(
+    plt_order = ["MLP", "MLP Ensemble", "Large Ensemble" "JEM", "JEM Ensemble"]
+)
+display(plt_additional_models)
+savefig(plt, joinpath(output_images_path, "mnist_eccco_additional.png"))
+```
+
 ### All digits
 
 ```{julia}
-- 
GitLab