diff --git a/.gitignore b/.gitignore index 632f382d3c3f6639eee8cb7dba5348c9e0255ec8..66c97be110009856c9648c45dac402b12daa0619 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ /.quarto/ /Manifest.toml /replicated/ +*/.CondaPkg/ # Tex diff --git a/artifacts/results/images/mnist_generated_JEM Ensemble.png b/artifacts/results/images/mnist_generated_JEM Ensemble.png index 727cb76a00365999bddc72e8bacba2e624d2abca..dc03f1c252fd1c91c0cbce49f1edc9ef7ba414f1 100644 Binary files a/artifacts/results/images/mnist_generated_JEM Ensemble.png and b/artifacts/results/images/mnist_generated_JEM Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_JEM.png b/artifacts/results/images/mnist_generated_JEM.png index 66e167e27972bafccda5fbca8831879ce375822c..ba4a42f4d9903ae16108eeceee2be2e8d19ed8e1 100644 Binary files a/artifacts/results/images/mnist_generated_JEM.png and b/artifacts/results/images/mnist_generated_JEM.png differ diff --git a/artifacts/results/images/mnist_generated_MLP Ensemble.png b/artifacts/results/images/mnist_generated_MLP Ensemble.png index f9f9b1c230769ff79a383f49a5ce41d2b25556ca..32e1c35d94ef76dcc1fac95e1b4d50f40b203b85 100644 Binary files a/artifacts/results/images/mnist_generated_MLP Ensemble.png and b/artifacts/results/images/mnist_generated_MLP Ensemble.png differ diff --git a/artifacts/results/images/mnist_generated_MLP.png b/artifacts/results/images/mnist_generated_MLP.png index 7ebfa028d077dd7755c0f37fb5e292f455b69582..24333578aa201ee4e90e1a2bda87c58bd08bfb00 100644 Binary files a/artifacts/results/images/mnist_generated_MLP.png and b/artifacts/results/images/mnist_generated_MLP.png differ diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index c8ef93995d9c56559b797f3c98aa10310b371c86..a55474be7a30c77ab1b046ad0c2794523c514d94 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls 
b/artifacts/results/mnist_vae_weak.jls index 93e046e518d94b1923573058ec144cac3acbe2b1..8ed61727f0b7c3e5775110beeca6df7f9c9d10ca 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/Manifest.toml b/notebooks/Manifest.toml index 95a166d6a93bd1e4f8e35c722449817b9ac28d63..b368fc0c1d35d718268f6afeb96c43ddac637d05 100644 --- a/notebooks/Manifest.toml +++ b/notebooks/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.9.0" manifest_format = "2.0" -project_hash = "0abd725f0aec8ca9f3de1d681c2fe36ba5f06d09" +project_hash = "b3a28b3d457734def39d7f7d129cad9e88b83474" [[deps.ARFFFiles]] deps = ["CategoricalArrays", "Dates", "Parsers", "Tables"] @@ -1392,8 +1392,8 @@ uuid = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f" version = "1.3.0" [[deps.LaplaceRedux]] -deps = ["Flux", "LinearAlgebra", "Parameters", "Plots", "Zygote"] -git-tree-sha1 = "a4adebbeafb96d0864b4833c254013f66dc6e0ee" +deps = ["CSV", "Compat", "ComputationalResources", "DataFrames", "Flux", "LinearAlgebra", "MLJ", "MLJBase", "MLJFlux", "MLJModelInterface", "MLUtils", "Parameters", "Plots", "ProgressMeter", "Random", "Serialization", "Statistics", "Tables", "Tullio", "Zygote"] +path = "../../LaplaceRedux.jl" uuid = "c52c1a26-f7c5-402b-80be-ba1e638ad478" version = "0.1.2" @@ -2781,6 +2781,12 @@ git-tree-sha1 = "4d4ed7f294cda19382ff7de4c137d24d16adc89b" uuid = "981d1d27-644d-49a2-9326-4793e63143c3" version = "0.1.0" +[[deps.Tullio]] +deps = ["ChainRulesCore", "DiffRules", "LinearAlgebra", "Requires"] +git-tree-sha1 = "7871a39eac745697ee512a87eeff06a048a7905b" +uuid = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc" +version = "0.3.5" + [[deps.TupleTools]] git-tree-sha1 = "3c712976c47707ff893cf6ba4354aa14db1d8938" uuid = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6" diff --git a/notebooks/Project.toml b/notebooks/Project.toml index c5d9bbffa366b1355b71ad4dff0808c1a098b2b3..f2fec0724ef5c015c3f75aed572ee906ce962b09 100644 --- a/notebooks/Project.toml +++ 
b/notebooks/Project.toml @@ -16,6 +16,7 @@ ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0" JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131" +LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3" MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458" MLJ = "add582a8-e3ab-11e8-2d5e-e98b27df1bc7" diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index 24fc7cca3378a36f04223ef7cce1b0b90dbfdda6..feedfb983c656e54325ec8ae822e7ddb7bd7cc7b 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -164,11 +164,11 @@ end ```{julia} # Hyper: -_retrain = true +_retrain = false _regen = true # Data: -n_obs = nothing +n_obs = 10000 counterfactual_data = load_mnist(n_obs) counterfactual_data.X = pre_process.(counterfactual_data.X) counterfactual_data.generative_model = vae @@ -441,60 +441,56 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco.png")) #### Additional Models (not in paper) -An MLP with regularization: +LeNet-5: ```{julia} -mlp_reg = NeuralNetworkClassifier( - builder=builder, - epochs=epochs, - batch_size=batch_size, - finaliser=_finaliser, - loss=_loss, - lambda=1.0, - alpha=0.5, -) -``` - -A CNN: - -```{julia} -mutable struct MyConvBuilder +mutable struct LeNetBuilder filter_size::Int channels1::Int channels2::Int - channels3::Int end -function MLJFlux.build(b::MyConvBuilder, rng, n_in, n_out, n_channels) +function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels) - k, c1, c2, c3 = b.filter_size, b.channels1, b.channels2, b.channels3 + k, c1, c2 = b.filter_size, b.channels1, b.channels2 mod(k, 2) == 1 || error("`filter_size` must be odd. 
") # padding to preserve image size on convolution: p = div(k - 1, 2) - front = Chain( + front = Flux.Chain( Conv((k, k), n_channels => c1, pad=(p, p), relu), MaxPool((2, 2)), Conv((k, k), c1 => c2, pad=(p, p), relu), MaxPool((2, 2)), - Conv((k, k), c2 => c3, pad=(p, p), relu), - MaxPool((2 ,2)), - flatten) + Flux.flatten) d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first - return Chain(front, Dense(d, n_out)) + back = Flux.Chain( + Dense(d, 120, relu), + Dense(120, 84, relu), + Dense(84, n_out), + ) + return Flux.Chain(front, back) end -clf = ImageClassifier(builder=MyConvBuilder(3, 16, 32, 32), - epochs=10, - loss=Flux.crossentropy) +# Final model: +lenet = ImageClassifier( + builder=LeNetBuilder(5, 6, 16), + epochs=epochs, + batch_size=batch_size, + finaliser=_finaliser, + loss=_loss, +) + +# Convert to image: function toimg(X) - X = MLUtils.stack(map(x -> reshape(x, 28,28), eachrow(matrix(X)))) + X = permutedims(matrix(X)) |> x -> reshape(x, 28, 28, :) X = coerce(X, GrayImage) return X end -clf = (X -> toimg(X)) |> clf + +lenet = (X -> toimg(X)) |> lenet ``` ```{julia} @@ -502,11 +498,10 @@ add_retrain = true # Deep Ensemble: mlp_large_ens = EnsembleModel(model=mlp, n=50) -mlp_drop_large_ens = EnsembleModel(model=mlp_reg, n=50) add_models = Dict( "Large Ensemble" => mlp_large_ens, - "Large Ensemble Reg" => mlp_drop_large_ens, + "LeNet-5" => lenet, ) if add_retrain @@ -519,13 +514,20 @@ end ``` ```{julia} +_plt_order = [ + "MLP", + "MLP Ensemble", + "Large Ensemble", + "JEM", + "JEM Ensemble", +] plt_additional_models, _, _ces_ = _plot_eccco_mnist( - plt_order = ["MLP", "MLP Ensemble", "Large Ensemble", "Large Ensemble Reg", "JEM", "JEM Ensemble"], + plt_order = _plt_order, model_dict=large_model_dict, wide = true, ) display(plt_additional_models) -savefig(plt, joinpath(output_images_path, "mnist_eccco_additional.png")) +savefig(plt_additional_models, joinpath(output_images_path, "mnist_eccco_additional.png")) ``` ### All digits diff --git a/notebooks/setup.jl 
b/notebooks/setup.jl index 4a17f7284cd597ff8c9c27eea5aa2d44081e8bf0..6ab61c89c272239e56bd1b7f8bec6e907d2dbf0d 100644 --- a/notebooks/setup.jl +++ b/notebooks/setup.jl @@ -24,6 +24,7 @@ setup_notebooks = quote using Flux using Images using JointEnergyModels + using LaplaceRedux: LaplaceApproximation using LinearAlgebra using Markdown using MLDatasets