From 87020517e326417f88a93692c266b53136fd3505 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Fri, 4 Aug 2023 15:27:59 +0200 Subject: [PATCH] additional models --- notebooks/mnist.qmd | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index b39cee0f..32f19287 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -433,6 +433,35 @@ display(plt) savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) ``` +#### Additional Models (not in paper) + +```{julia} +add_retrain = true + +# Deep Ensemble: +mlp_large_ens = EnsembleModel(model=mlp, n=50) + +add_models = Dict( + "Large MLP Ensemble" => mlp_large_ens, +) + +if add_retrain + add_model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in add_models) + large_model_dict = merge(model_dict, add_model_dict) + Serialization.serialize(joinpath(output_path,"mnist_models_large.jls"), large_model_dict) +else + large_model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models_large.jls")) +end +``` + +```{julia} +plt_additional_models, _, _ces_ = _plot_eccco_mnist( + plt_order = ["MLP", "MLP Ensemble", "Large Ensemble" "JEM", "JEM Ensemble"] +) +display(plt_additional_models) +savefig(plt, joinpath(output_images_path, "mnist_eccco_additional.png")) +``` + ### All digits ```{julia} -- GitLab