diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index b39cee0ff557e2d0b9aab9557d7e316650ae5f14..32f19287e083c96229564242d16d76d3f7fe0652 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -433,6 +433,35 @@ display(plt) savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) ``` +#### Additional Models (not in paper) + +```{julia} +add_retrain = true + +# Deep Ensemble: +mlp_large_ens = EnsembleModel(model=mlp, n=50) + +add_models = Dict( + "Large MLP Ensemble" => mlp_large_ens, +) + +if add_retrain + add_model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in add_models) + large_model_dict = merge(model_dict, add_model_dict) + Serialization.serialize(joinpath(output_path,"mnist_models_large.jls"), large_model_dict) +else + large_model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models_large.jls")) +end +``` + +```{julia} +plt_additional_models, _, _ces_ = _plot_eccco_mnist( + plt_order = ["MLP", "MLP Ensemble", "Large Ensemble" "JEM", "JEM Ensemble"] +) +display(plt_additional_models) +savefig(plt, joinpath(output_images_path, "mnist_eccco_additional.png")) +``` + ### All digits ```{julia}