From a933cc9bb2f40a313660ebefd37c6aa55d828c67 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 8 Aug 2023 14:10:45 +0200
Subject: [PATCH] LeNet implemented

---
 artifacts/results/mnist_vae.jls      | Bin 1003682 -> 1003682 bytes
 artifacts/results/mnist_vae_weak.jls | Bin 105130 -> 105130 bytes
 notebooks/mnist.qmd                  | 35 ++++++++++++---------------
 src/model.jl                         | 32 +++++++++++++-----------
 4 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index 872e5a43473de969e478c182fe76d324da88d0dc..0eed13cca664e891d5f09c05ef2f341b4fe67efd 100644
GIT binary patch
delta 113
zcmZ3q(00*6+lCg#7N!>F7M3lnUNYM^wQ}ub-2Tavb*04iE4p0owYPi8umVL6M}pY1
zlUNxU+vldR0x=s9vjZ{5_PHsX<|o^;+qr<48;E&;n0I@2JD=If?X#0vnHaY-%dx)G
HX7mLB{PQbP

delta 113
zcmZ3q(00*6+lCg#7N!>F7M3lnUNYNfC$a8i+`d_b|DE=BCOOts65E|8ure}k-=xgX
z$k^^Okrjy9fS4VKIkvk@<TO9IeT^CGd+qHZDIn5I2B^QiI*JR3xq+Amh<UeHNAa1R
HWb_395Bex|

diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index 42d8d05d0113c303f09bad12ce13865f76cff0cb..d1c8752d6b813e87b8ae1e8c82bf6125a31a30d7 100644
GIT binary patch
delta 63
zcmZ3rm2K5lwhdlg(-*QbmT&gzT3<EYla0}Hx)zYT{WcrpN{Q{i<r&viP5;5dXf$1d
TjnQy=i#%ib^hs=t9*n*KAH5cv

delta 62
zcmZ3rm2K5lwhdlgn-6xaubSR5lhJVcPFBY9$zEOM(|3Y+1+0vo(-UVgdTiHcV_YRM
S{nbpy^6fX}7}r%X`T_uBf*HL4

diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 3fa961e0..b5b68d58 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -388,11 +388,6 @@ function _plot_eccco_mnist(
 
     ces = Dict()
     for (mod_name, mod) in model_dict
-        if mod.model.model isa ProbabilisticPipeline
-            x = mod.model.model.f(x)
-            _data = counterfactual_data
-            _data.X = mod.model.model.f(_data.X)
-        end
         ce = generate_counterfactual(
             x, target, counterfactual_data, mod, eccco_generator;
             decision_threshold=γ, max_iter=T,
@@ -455,7 +450,11 @@ mutable struct LeNetBuilder
     channels2::Int
 end
 
-function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
+preproc(X) = reshape(X, (28, 28, 1, :))
+
+function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
+
+    _n_in = Int(sqrt(n_in))
 
     k, c1, c2 = b.filter_size, b.channels1, b.channels2
 
@@ -464,42 +463,35 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
     # padding to preserve image size on convolution:
     p = div(k - 1, 2)
 
+    preproc(x) = reshape(x, (_n_in, _n_in, 1, :))
+
     front = Flux.Chain(
-        Conv((k, k), n_channels => c1, pad=(p, p), relu),
+        Conv((k, k), 1 => c1, pad=(p, p), relu),
         MaxPool((2, 2)),
         Conv((k, k), c1 => c2, pad=(p, p), relu),
         MaxPool((2, 2)),
         Flux.flatten
     )
 
-    d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
+    d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
 
     back = Flux.Chain(
         Dense(d, 120, relu),
         Dense(120, 84, relu),
         Dense(84, n_out),
     )
-    chain = Flux.Chain(front, back)
+    chain = Flux.Chain(preproc, front, back)
     return chain
 end
 
 # Final model:
-lenet = ImageClassifier(
+lenet = NeuralNetworkClassifier(
     builder=LeNetBuilder(5, 6, 16),
     epochs=epochs,
     batch_size=batch_size,
     finaliser=_finaliser,
     loss=_loss,
 )
-
-# Convert to image:
-function toimg(X)
-    X = matrix(X) |> x -> reshape(x, 28, 28, :)
-    X = coerce(X, GrayImage)
-    return X
-end
-
-lenet = (X -> toimg(X)) |> lenet
 ```
 
 ```{julia}
@@ -508,8 +500,12 @@ add_retrain = true
 
 # Deep Ensemble:
 mlp_large_ens = EnsembleModel(model=mlp, n=50)
 
+# LeNet-5 Ensemble:
+lenet_ens = EnsembleModel(model=lenet, n=5)
+
 add_models = Dict(
     "LeNet-5" => lenet,
+    "LeNet-5 Ensemble" => lenet_ens,
     "Large Ensemble (n=50)" => mlp_large_ens,
 )
@@ -528,6 +524,7 @@ _plt_order = [
     "MLP Ensemble",
     "Large Ensemble (n=50)",
     "LeNet-5",
+    "LeNet-5 Ensemble",
     "JEM",
     "JEM Ensemble",
 ]
diff --git a/src/model.jl b/src/model.jl
index bf3dbe8a..a865d559 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,3 +1,4 @@
+using ChainRules: ignore_derivatives
 using ConformalPrediction
 using CounterfactualExplanations.Models
 using Flux
@@ -38,13 +39,18 @@ end
 Private function that extracts the chains from a fitted model.
 """
 function _get_chains(fitresult)
-    if fitresult isa MLJEnsembles.WrappedEnsemble
-        chains = map(res -> res[1], fitresult.ensemble)
-    elseif fitresult isa MLJBase.Signature
-        chains = [fitted_params(fitresult)[:image_classifier][1]] # for piped image classifiers
-    else
-        chains = [fitresult[1]]
+
+    chains = []
+
+    ignore_derivatives() do
+        if fitresult isa MLJEnsembles.WrappedEnsemble
+            _chains = map(res -> res[1], fitresult.ensemble)
+        else
+            _chains = [fitresult[1]]
+        end
+        push!(chains, _chains...)
     end
+
     return chains
 end
 
@@ -147,7 +153,9 @@ In the binary case logits are fed through the sigmoid function instead of softmax,
 which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
 """
 function Models.logits(M::ConformalModel, X::AbstractArray)
+
     fitresult = M.fitresult
+
     function predict_logits(fitresult, x)
         ŷ = MLUtils.stack(map(chain -> get_logits(chain,x),_get_chains(fitresult))) |>
             y -> mean(y, dims=ndims(y)) |>
@@ -162,14 +170,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
         ŷ = ndims(ŷ) > 1 ? ŷ : permutedims([ŷ])
         return ŷ
     end
-    if ndims(X) > 2
-        yhat = map(eachslice(X, dims=ndims(X))) do x
-            predict_logits(fitresult, x)
-        end
-        yhat = MLUtils.stack(yhat)
-    else
-        yhat = predict_logits(fitresult, X)
-    end
+
+    yhat = predict_logits(fitresult, X)
+
     return yhat
 end
 
@@ -179,6 +182,7 @@ end
 Returns the estimated probabilities for a Conformal Classifier.
 """
 function Models.probs(M::ConformalModel, X::AbstractArray)
+
     if M.likelihood == :classification_binary
         output = σ.(Models.logits(M, X))
     elseif M.likelihood == :classification_multi
-- 
GitLab
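
For context, the core change in `notebooks/mnist.qmd` swaps the `ImageClassifier` plus `toimg`/`GrayImage` coercion pipeline for a plain `NeuralNetworkClassifier` whose chain reshapes the flattened input itself. Below is a minimal standalone sketch of the chain this builder now produces, assuming only Flux.jl; the sizes mirror `LeNetBuilder(5, 6, 16)` from the diff above, and the dummy batch is hypothetical, used purely as a shape check.

```julia
using Flux

# Mirrors the reworked MLJFlux.build for LeNetBuilder(5, 6, 16) on flattened MNIST.
n_in, n_out = 28 * 28, 10               # flattened image size, number of classes
_n_in = Int(sqrt(n_in))                 # 28
k, c1, c2 = 5, 6, 16                    # filter size and channel widths
p = div(k - 1, 2)                       # padding to preserve image size on convolution

# Key change in the patch: reshape flat input to WHCN image format inside the chain.
preproc(x) = reshape(x, (_n_in, _n_in, 1, :))

front = Flux.Chain(
    Conv((k, k), 1 => c1, pad=(p, p), relu),
    MaxPool((2, 2)),
    Conv((k, k), c1 => c2, pad=(p, p), relu),
    MaxPool((2, 2)),
    Flux.flatten,
)
d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first   # 16 * 7 * 7 = 784
back = Flux.Chain(Dense(d, 120, relu), Dense(120, 84, relu), Dense(84, n_out))
chain = Flux.Chain(preproc, front, back)

X = rand(Float32, n_in, 32)             # hypothetical batch of 32 flattened images
@assert size(chain(X)) == (n_out, 32)   # one logit per class, per observation
```

Because the chain now handles the image reshape itself, the `MLJBase.Signature` branch in `_get_chains` (which existed only for the piped image classifier) could be dropped as well.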