From 189818cbc6c1594f22d84c1dcb0463fdd2ebf575 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Mon, 24 Apr 2023 18:01:18 +0200
Subject: [PATCH] mnist

---
 Manifest.toml                        |  28 ++++-----
 artifacts/results/mnist_vae.jls      | Bin 1003702 -> 1003702 bytes
 artifacts/results/mnist_vae_weak.jls | Bin 199710 -> 199710 bytes
 notebooks/Manifest.toml              |  24 ++++----
 notebooks/mnist.qmd                  |  81 +++++++++++++++------------
 5 files changed, 70 insertions(+), 63 deletions(-)

diff --git a/Manifest.toml b/Manifest.toml
index ad348f3b..65a6cf1f 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -17,9 +17,9 @@ version = "0.4.4"
 
 [[deps.Accessors]]
 deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"]
-git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c"
+git-tree-sha1 = "c7dddee3f32ceac12abd9a21cd0c4cb489f230d2"
 uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
-version = "0.1.28"
+version = "0.1.29"
 
 [[deps.Adapt]]
 deps = ["LinearAlgebra", "Requires"]
@@ -580,15 +580,15 @@ version = "0.19.3"
 
 [[deps.GR]]
 deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
-git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362"
+git-tree-sha1 = "db730189e3d250d97515a91886de7e33aa8833e6"
 uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
-version = "0.72.1"
+version = "0.72.2"
 
 [[deps.GR_jll]]
 deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3"
+git-tree-sha1 = "47a2efe07729dd508a032e2f56c46c517481052a"
 uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
-version = "0.72.1+1"
+version = "0.72.2+0"
 
 [[deps.GZip]]
 deps = ["Libdl"]
@@ -785,7 +785,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
@@ -862,9 +862,9 @@ version = "0.1.2"
 
 [[deps.Latexify]]
 deps = ["Formatting", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Printf", "Requires"]
-git-tree-sha1 = "2422f47b34d4b127720a18f86fa7b1aa2e141f29"
+git-tree-sha1 = "ee342fcc2b8762c43a60dfbbf73bc2258703af19"
 uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
-version = "0.15.18"
+version = "0.15.19"
 
 [[deps.LazyArtifacts]]
 deps = ["Artifacts", "Pkg"]
@@ -1016,7 +1016,7 @@ version = "0.3.2"
 deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
 path = "../MLJFlux.jl"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
-version = "0.2.9"
+version = "0.2.10"
 
 [[deps.MLJModelInterface]]
 deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
@@ -1322,9 +1322,9 @@ version = "1.8.0"
 
 [[deps.PkgTemplates]]
 deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"]
-git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3"
+git-tree-sha1 = "b8e88d61d55607c07ac1ed9dabf474cfab8490b9"
 uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
-version = "0.7.33"
+version = "0.7.34"
 
 [[deps.PlotThemes]]
 deps = ["PlotUtils", "Statistics"]
@@ -1702,9 +1702,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [[deps.TimerOutputs]]
 deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b"
+git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7"
 uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.22"
+version = "0.5.23"
 
 [[deps.Tracker]]
 deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"]
diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index cb9f75b93617511d75340cb68ff57b2b3e59cf74..0dbb47ae5322013c4017e5495b4b7afc1c8d5ad5 100644
GIT binary patch
delta 113
zcmdnC(01EG+lCg#7N!>F7M3lnaWdP#ma?vt*zW7Wx|4DH{C-wO#_j!kK<qdf*7w@m
z7fb+=^ZJ0|?e*<kK+FxqJV4C5y}q5#>}0$9K2{)R17da{=GgAOkJIAhb~Z28ciN1;
E00MI?ga7~l

delta 113
zcmdnC(01EG+lCg#7N!>F7M3lnaWdPhELq=cZ)Z1U-O0H9xC;L}?d`ELtg9rpM@(d8
zWZZth3`Az8g47%|2dZhG5X}X|+(66&#Jt-lMDv-SY>%A83dC$c%nrmH+ao7&nw@0y
F1pr-2Dc=A9

diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index b8bd092b81cda8bbdb19a3daa4d2974588d7513e..e569cb9fc2f85467a252315b1924fbf94eb7fd96 100644
GIT binary patch
delta 66
zcmbO?gJ<3ho`x-qb*$5O2s0IK{~*Go|8@E^F6QIYHAR??PnQR=*8XBN+|K`-aizrc
WwcN}F+iO`F4Zcpl_lwbh(H8)JHX3aJ

delta 66
zcmbO?gJ<3ho`x-qb*$U>i81MaonHKx(Qx`(Zf1k+i~lmNl9*n{%2+u4>2Jp4+uyS?
W0_E5IVLU!v4#=PGE5=m7=nDYKF&x(b

diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml
index 899c4ea1..0c580895 100644
--- a/notebooks/Manifest.toml
+++ b/notebooks/Manifest.toml
@@ -700,15 +700,15 @@ version = "0.19.3"
 
 [[deps.GR]]
 deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"]
-git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362"
+git-tree-sha1 = "db730189e3d250d97515a91886de7e33aa8833e6"
 uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71"
-version = "0.72.1"
+version = "0.72.2"
 
 [[deps.GR_jll]]
 deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"]
-git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3"
+git-tree-sha1 = "47a2efe07729dd508a032e2f56c46c517481052a"
 uuid = "d2c73de3-f751-5644-a686-071e5b155ba9"
-version = "0.72.1+1"
+version = "0.72.2+0"
 
 [[deps.GZip]]
 deps = ["Libdl"]
@@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 version = "1.12.0"
 
 [[deps.JointEnergyModels]]
-deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
+deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"]
 path = "../../JointEnergyModels.jl"
 uuid = "48c56d24-211d-4463-bbc0-7a701b291131"
 version = "0.1.0"
@@ -1135,9 +1135,9 @@ version = "0.1.2"
 
 [[deps.Latexify]]
 deps = ["Formatting", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Printf", "Requires"]
-git-tree-sha1 = "2422f47b34d4b127720a18f86fa7b1aa2e141f29"
+git-tree-sha1 = "ee342fcc2b8762c43a60dfbbf73bc2258703af19"
 uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
-version = "0.15.18"
+version = "0.15.19"
 
 [[deps.LazyArtifacts]]
 deps = ["Artifacts", "Pkg"]
@@ -1295,7 +1295,7 @@ version = "0.3.2"
 deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"]
 path = "../../MLJFlux.jl"
 uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845"
-version = "0.2.9"
+version = "0.2.10"
 
 [[deps.MLJModelInterface]]
 deps = ["Random", "ScientificTypesBase", "StatisticalTraits"]
@@ -1672,9 +1672,9 @@ version = "1.8.0"
 
 [[deps.PkgTemplates]]
 deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"]
-git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3"
+git-tree-sha1 = "b8e88d61d55607c07ac1ed9dabf474cfab8490b9"
 uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1"
-version = "0.7.33"
+version = "0.7.34"
 
 [[deps.PkgVersion]]
 deps = ["Pkg"]
@@ -2169,9 +2169,9 @@ version = "0.3.1"
 
 [[deps.TimerOutputs]]
 deps = ["ExprTools", "Printf"]
-git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b"
+git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7"
 uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
-version = "0.5.22"
+version = "0.5.23"
 
 [[deps.Tracker]]
 deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"]
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 2c7f5b9b..696cd786 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -228,54 +228,61 @@ models = Dict(
 
 ```{julia}
 # Train models:
-function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive)
-    conf_model = conformal_model(jem_ens; method=method, coverage=cov)
+function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")
+    conf_model = conformal_model(model; method=method, coverage=cov)
     mach = machine(conf_model, X, y)
+    @info "Begin training $mod_name."
     fit!(mach)
+    @info "Finished training $mod_name."
     M = ECCCo.ConformalModel(mach.model, mach.fitresult)
     return M
 end
-model_dict = Dict(mod_name => _train(mod) for (mod_name, mod) in models)
+model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
+Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict)
 ```
 
 ```{julia}
-if mach.model.model isa JointEnergyClassifier
-    sampler = mach.model.model.jem.sampler
-else
-    K = length(counterfactual_data.y_levels)
-    input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
-    𝒟x = Uniform(extrema(counterfactual_data.X)...)
-    𝒟y = Categorical(ones(K) ./ K)
-    sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
-end
-opt = ImproperSGLD()
-f(x) = logits(M, x)
-
-n_iter = 200
-_w = 1500
-plts = []
-neach = 10
-for i in 1:10
-    x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
-    plts_i = []
-    for j in 1:size(x, 2)
-        xj = x[:,j]
-        xj = reshape(xj, (n_digits, n_digits))
-        plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+for (mod_name, mod) in model_dict
+
+    # Plot:
+    if mod.model.model isa JointEnergyClassifier
+        sampler = mod.model.model.jem.sampler
+    elseif mod.model.model.model isa JointEnergyClassifier
+        sampler = mod.model.model.model.jem.sampler
+    else
+        K = length(counterfactual_data.y_levels)
+        input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
+        𝒟x = Uniform(extrema(counterfactual_data.X)...)
+        𝒟y = Categorical(ones(K) ./ K)
+        sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
     end
-    plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
-    plts = [plts..., plt]
-end
-plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1))
-display(plt)
+    opt = ImproperSGLD()
+    f(x) = logits(mod, x)
 
-```
+    n_iter = 200
+    _w = 1500
+    plts = []
+    neach = 10
+    for i in 1:10
+        x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
+        plts_i = []
+        for j in 1:size(x, 2)
+            xj = x[:,j]
+            xj = reshape(xj, (n_digits, n_digits))
+            plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)]
+        end
+        plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10))
+        plts = [plts..., plt]
+    end
+    plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
+    display(plt)
 
-```{julia}
-test_data = load_mnist_test()
-test_data.X = pre_process.(test_data.X)
-f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
-println("F1 score (test): $(round(f1,digits=3))")
+    # Test performance:
+    test_data = load_mnist_test()
+    test_data.X = pre_process.(test_data.X)
+    f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
+    println("F1 score (test): $(round(f1,digits=3))")
+end
 ```
 
 ```{julia}
-- 
GitLab