From 189818cbc6c1594f22d84c1dcb0463fdd2ebf575 Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Mon, 24 Apr 2023 18:01:18 +0200 Subject: [PATCH] mnist --- Manifest.toml | 28 ++++----- artifacts/results/mnist_vae.jls | Bin 1003702 -> 1003702 bytes artifacts/results/mnist_vae_weak.jls | Bin 199710 -> 199710 bytes notebooks/Manifest.toml | 24 ++++---- notebooks/mnist.qmd | 81 +++++++++++++++------------ 5 files changed, 70 insertions(+), 63 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index ad348f3b..65a6cf1f 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -17,9 +17,9 @@ version = "0.4.4" [[deps.Accessors]] deps = ["Compat", "CompositionsBase", "ConstructionBase", "Dates", "InverseFunctions", "LinearAlgebra", "MacroTools", "Requires", "StaticArrays", "Test"] -git-tree-sha1 = "beabc31fa319f9de4d16372bff31b4801e43d32c" +git-tree-sha1 = "c7dddee3f32ceac12abd9a21cd0c4cb489f230d2" uuid = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" -version = "0.1.28" +version = "0.1.29" [[deps.Adapt]] deps = ["LinearAlgebra", "Requires"] @@ -580,15 +580,15 @@ version = "0.19.3" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] -git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362" +git-tree-sha1 = "db730189e3d250d97515a91886de7e33aa8833e6" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.72.1" +version = "0.72.2" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3" +git-tree-sha1 = "47a2efe07729dd508a032e2f56c46c517481052a" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.72.1+1" +version = "0.72.2+0" [[deps.GZip]] deps = ["Libdl"] @@ -785,7 +785,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] path = "../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" @@ -862,9 +862,9 @@ version = "0.1.2" [[deps.Latexify]] deps = ["Formatting", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Printf", "Requires"] -git-tree-sha1 = "2422f47b34d4b127720a18f86fa7b1aa2e141f29" +git-tree-sha1 = "ee342fcc2b8762c43a60dfbbf73bc2258703af19" uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.15.18" +version = "0.15.19" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -1016,7 +1016,7 @@ version = "0.3.2" deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] path = "../MLJFlux.jl" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.2.9" +version = "0.2.10" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] @@ -1322,9 +1322,9 @@ version = "1.8.0" [[deps.PkgTemplates]] deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"] -git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3" +git-tree-sha1 = "b8e88d61d55607c07ac1ed9dabf474cfab8490b9" uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" -version = "0.7.33" +version = "0.7.34" [[deps.PlotThemes]] deps = ["PlotUtils", "Statistics"] @@ -1702,9 +1702,9 @@ uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.22" +version = "0.5.23" [[deps.Tracker]] deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index cb9f75b93617511d75340cb68ff57b2b3e59cf74..0dbb47ae5322013c4017e5495b4b7afc1c8d5ad5 100644 GIT binary patch delta 113 zcmdnC(01EG+lCg#7N!>F7M3lnaWdP#ma?vt*zW7Wx|4DH{C-wO#_j!kK<qdf*7w@m z7fb+=^ZJ0|?e*<kK+FxqJV4C5y}q5#>}0$9K2{)R17da{=GgAOkJIAhb~Z28ciN1; E00MI?ga7~l delta 113 zcmdnC(01EG+lCg#7N!>F7M3lnaWdPhELq=cZ)Z1U-O0H9xC;L}?d`ELtg9rpM@(d8 zWZZth3`Az8g47%|2dZhG5X}X|+(66&#Jt-lMDv-SY>%A83dC$c%nrmH+ao7&nw@0y F1pr-2Dc=A9 diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index b8bd092b81cda8bbdb19a3daa4d2974588d7513e..e569cb9fc2f85467a252315b1924fbf94eb7fd96 100644 GIT binary patch delta 66 zcmbO?gJ<3ho`x-qb*$5O2s0IK{~*Go|8@E^F6QIYHAR??PnQR=*8XBN+|K`-aizrc WwcN}F+iO`F4Zcpl_lwbh(H8)JHX3aJ delta 66 zcmbO?gJ<3ho`x-qb*$U>i81MaonHKx(Qx`(Zf1k+i~lmNl9*n{%2+u4>2Jp4+uyS? W0_E5IVLU!v4#=PGE5=m7=nDYKF&x(b diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 899c4ea1..0c580895 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -700,15 +700,15 @@ version = "0.19.3" [[deps.GR]] deps = ["Artifacts", "Base64", "DelimitedFiles", "Downloads", "GR_jll", "HTTP", "JSON", "Libdl", "LinearAlgebra", "Pkg", "Preferences", "Printf", "Random", "Serialization", "Sockets", "TOML", "Tar", "Test", "UUIDs", "p7zip_jll"] -git-tree-sha1 = "011a22022ed2fb0352a9bded0fa9d3793a8db362" +git-tree-sha1 = "db730189e3d250d97515a91886de7e33aa8833e6" uuid = "28b8d3ca-fb5f-59d9-8090-bfdbd6d07a71" -version = "0.72.1" +version = "0.72.2" [[deps.GR_jll]] deps = ["Artifacts", "Bzip2_jll", "Cairo_jll", "FFMPEG_jll", "Fontconfig_jll", "GLFW_jll", "JLLWrappers", "JpegTurbo_jll", "Libdl", "Libtiff_jll", "Pixman_jll", "Qt5Base_jll", "Zlib_jll", "libpng_jll"] -git-tree-sha1 = "7ea8ead860c85b27e83d198ea54bb2f387db9fc3" +git-tree-sha1 = "47a2efe07729dd508a032e2f56c46c517481052a" uuid = "d2c73de3-f751-5644-a686-071e5b155ba9" -version = "0.72.1+1" +version = "0.72.2+0" [[deps.GZip]] deps = ["Libdl"] @@ -1052,7 +1052,7 @@ uuid = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" version = "1.12.0" [[deps.JointEnergyModels]] -deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] +deps = ["ChainRulesCore", "ComputationalResources", "Distributions", "Flux", "MLJFlux", "MLJModelInterface", "MLUtils", "PkgTemplates", "ProgressMeter", "Random", "StatsBase", "Tables"] path = "../../JointEnergyModels.jl" uuid = "48c56d24-211d-4463-bbc0-7a701b291131" version = "0.1.0" @@ -1135,9 +1135,9 @@ version = "0.1.2" [[deps.Latexify]] deps = ["Formatting", "InteractiveUtils", "LaTeXStrings", "MacroTools", "Markdown", "OrderedCollections", "Printf", "Requires"] -git-tree-sha1 = "2422f47b34d4b127720a18f86fa7b1aa2e141f29" +git-tree-sha1 = "ee342fcc2b8762c43a60dfbbf73bc2258703af19" uuid = "23fbe1c1-3f47-55db-b15f-69d7ec21a316" -version = "0.15.18" +version = "0.15.19" [[deps.LazyArtifacts]] deps = ["Artifacts", "Pkg"] @@ -1295,7 +1295,7 @@ version = "0.3.2" deps = ["CategoricalArrays", "ColorTypes", "ComputationalResources", "Flux", "MLJModelInterface", "Metalhead", "ProgressMeter", "Random", "Statistics", "Tables"] path = "../../MLJFlux.jl" uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845" -version = "0.2.9" +version = "0.2.10" [[deps.MLJModelInterface]] deps = ["Random", "ScientificTypesBase", "StatisticalTraits"] @@ -1672,9 +1672,9 @@ version = "1.8.0" [[deps.PkgTemplates]] deps = ["Dates", "InteractiveUtils", "LibGit2", "Mocking", "Mustache", "Parameters", "Pkg", "REPL", "UUIDs"] -git-tree-sha1 = "c0f12580abb41d7d11c1c7c65a1ff410f84c61e3" +git-tree-sha1 = "b8e88d61d55607c07ac1ed9dabf474cfab8490b9" uuid = "14b8a8f1-9102-5b29-a752-f990bacb7fe1" -version = "0.7.33" +version = "0.7.34" [[deps.PkgVersion]] deps = ["Pkg"] @@ -2169,9 +2169,9 @@ version = "0.3.1" [[deps.TimerOutputs]] deps = ["ExprTools", "Printf"] -git-tree-sha1 = "f2fd3f288dfc6f507b0c3a2eb3bac009251e548b" +git-tree-sha1 = "f548a9e9c490030e545f72074a41edfd0e5bcdd7" uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -version = "0.5.22" +version = "0.5.23" [[deps.Tracker]] deps = ["Adapt", "DiffRules", "ForwardDiff", "LinearAlgebra", "LogExpFunctions", "MacroTools", "NNlib", "NaNMath", "Printf", "Random", "Requires", "SpecialFunctions", "Statistics"] diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 2c7f5b9b..696cd786 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -228,54 +228,61 @@ models = Dict( ```{julia} # Train models: -function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive) - conf_model = conformal_model(jem_ens; method=method, coverage=cov) +function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model") + conf_model = conformal_model(model; method=method, coverage=cov) mach = machine(conf_model, X, y) + @info "Begin training $mod_name." fit!(mach) + @info "Finished training $mod_name." M = ECCCo.ConformalModel(mach.model, mach.fitresult) return M end -model_dict = Dict(mod_name => _train(mod) for (mod_name, mod) in models) +model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models) +Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict) ``` ```{julia} -if mach.model.model isa JointEnergyClassifier - sampler = mach.model.model.jem.sampler -else - K = length(counterfactual_data.y_levels) - input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) - ð’Ÿx = Uniform(extrema(counterfactual_data.X)...) - ð’Ÿy = Categorical(ones(K) ./ K) - sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy; input_size=input_size) -end -opt = ImproperSGLD() -f(x) = logits(M, x) - -n_iter = 200 -_w = 1500 -plts = [] -neach = 10 -for i in 1:10 - x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) - plts_i = [] - for j in 1:size(x, 2) - xj = x[:,j] - xj = reshape(xj, (n_digits, n_digits)) - plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] +for (mod_name, mod) in model_dict + + # Plot: + if mod.model.model isa JointEnergyClassifier + sampler = mod.model.model.jem.sampler + elseif mod.model.model.model isa JointEnergyClassifier + sampler = mod.model.model.model.jem.sampler + else + K = length(counterfactual_data.y_levels) + input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1)) + ð’Ÿx = Uniform(extrema(counterfactual_data.X)...) + ð’Ÿy = Categorical(ones(K) ./ K) + sampler = ConditionalSampler(ð’Ÿx, ð’Ÿy; input_size=input_size) end - plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) - plts = [plts..., plt] -end -plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1)) -display(plt) + opt = ImproperSGLD() + f(x) = logits(mod, x) -``` + n_iter = 200 + _w = 1500 + plts = [] + neach = 10 + for i in 1:10 + x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i) + plts_i = [] + for j in 1:size(x, 2) + xj = x[:,j] + xj = reshape(xj, (n_digits, n_digits)) + plts_i = [plts_i..., Plots.heatmap(rotl90(xj), axis=nothing, cb=false)] + end + plt = Plots.plot(plts_i..., size=(_w,0.10*_w), layout=(1,10)) + plts = [plts..., plt] + end + plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name) + display(plt) -```{julia} -test_data = load_mnist_test() -test_data.X = pre_process.(test_data.X) -f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) -println("F1 score (test): $(round(f1,digits=3))") + # Test performance: + test_data = load_mnist_test() + test_data.X = pre_process.(test_data.X) + f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data) + println("F1 score (test): $(round(f1,digits=3))") +end ``` ```{julia} -- GitLab