From a933cc9bb2f40a313660ebefd37c6aa55d828c67 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 8 Aug 2023 14:10:45 +0200
Subject: [PATCH] Implement LeNet-5 classifier

Switch the LeNet model from a piped ImageClassifier to an MLJFlux
NeuralNetworkClassifier whose chain reshapes flattened input internally,
add a LeNet-5 ensemble, and simplify _get_chains and Models.logits
accordingly.

---
 artifacts/results/mnist_vae.jls      | Bin 1003682 -> 1003682 bytes
 artifacts/results/mnist_vae_weak.jls | Bin 105130 -> 105130 bytes
 notebooks/mnist.qmd                  |  35 ++++++++++++---------------
 src/model.jl                         |  32 +++++++++++++-----------
 4 files changed, 34 insertions(+), 33 deletions(-)

diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index 872e5a43473de969e478c182fe76d324da88d0dc..0eed13cca664e891d5f09c05ef2f341b4fe67efd 100644
GIT binary patch
delta 113
zcmZ3q(00*6+lCg#7N!>F7M3lnUNYM^wQ}ub-2Tavb*04iE4p0owYPi8umVL6M}pY1
zlUNxU+vldR0x=s9vjZ{5_PHsX<|o^;+qr<48;E&;n0I@2JD=If?X#0vnHaY-%dx)G
HX7mLB{PQbP

delta 113
zcmZ3q(00*6+lCg#7N!>F7M3lnUNYNfC$a8i+`d_b|DE=BCOOts65E|8ure}k-=xgX
z$k^^Okrjy9fS4VKIkvk@<TO9IeT^CGd+qHZDIn5I2B^QiI*JR3xq+Amh<UeHNAa1R
HWb_395Bex|

diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index 42d8d05d0113c303f09bad12ce13865f76cff0cb..d1c8752d6b813e87b8ae1e8c82bf6125a31a30d7 100644
GIT binary patch
delta 63
zcmZ3rm2K5lwhdlg(-*QbmT&gzT3<EYla0}Hx)zYT{WcrpN{Q{i<r&viP5;5dXf$1d
TjnQy=i#%ib^hs=t9*n*KAH5cv

delta 62
zcmZ3rm2K5lwhdlgn-6xaubSR5lhJVcPFBY9$zEOM(|3Y+1+0vo(-UVgdTiHcV_YRM
S{nbpy^6fX}7}r%X`T_uBf*HL4

diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 3fa961e0..b5b68d58 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -388,11 +388,6 @@ function _plot_eccco_mnist(
 
     ces = Dict()
     for (mod_name, mod) in model_dict
-        if mod.model.model isa ProbabilisticPipeline
-            x = mod.model.model.f(x)
-            _data = counterfactual_data
-            _data.X = mod.model.model.f(_data.X)
-        end
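+        # Preprocessing now lives inside each model's chain (see preproc below),
+        # so no pipeline-specific handling is needed here.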
         ce = generate_counterfactual(
             x, target, counterfactual_data, mod, eccco_generator; 
             decision_threshold=γ, max_iter=T,
@@ -455,7 +450,11 @@ mutable struct LeNetBuilder
 	channels2::Int
 end
 
-function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
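+# Reshape flattened 784×N input to 28×28×1×N (WHCN), the layout Flux Conv layers expect: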
+preproc(X) = reshape(X, (28, 28, 1, :))
+
+function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
+
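+    # Recover the square image side length from the flattened input dimension (e.g. 784 -> 28):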
+    _n_in = Int(sqrt(n_in))
 
 	k, c1, c2 = b.filter_size, b.channels1, b.channels2
 
@@ -464,42 +463,35 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
 	# padding to preserve image size on convolution:
 	p = div(k - 1, 2)
 
+    # Local reshape closure (captures _n_in and shadows the top-level preproc):
+    preproc(x) = reshape(x, (_n_in, _n_in, 1, :))
+
 	front = Flux.Chain(
-        Conv((k, k), n_channels => c1, pad=(p, p), relu),
+        Conv((k, k), 1 => c1, pad=(p, p), relu),
         MaxPool((2, 2)),
         Conv((k, k), c1 => c2, pad=(p, p), relu),
         MaxPool((2, 2)),
         Flux.flatten
     )
-	d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
+	d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
     back = Flux.Chain(
         Dense(d, 120, relu),
         Dense(120, 84, relu),
         Dense(84, n_out),
     )
 
-    chain = Flux.Chain(front, back)
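+    # Prepend preproc so the chain accepts flat input directly (no separate image pipeline needed):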
+    chain = Flux.Chain(preproc, front, back)
 
 	return chain
 end
 
 # Final model:
-lenet = ImageClassifier(
+lenet = NeuralNetworkClassifier(
     builder=LeNetBuilder(5, 6, 16),
     epochs=epochs,
     batch_size=batch_size,
     finaliser=_finaliser,
     loss=_loss,
 )
-
-# Convert to image:
-function toimg(X)
-    X = matrix(X) |> x -> reshape(x, 28, 28, :)
-    X = coerce(X, GrayImage)
-    return X
-end
-
-lenet = (X -> toimg(X)) |> lenet
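+
+# Illustrative smoke test (a sketch; assumes X and y as prepared in the earlier
+# data cells, with X holding the flattened 784-pixel features):
+# mach = machine(lenet, X, y)
+# fit!(mach)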
 ```
 
 ```{julia}
@@ -508,8 +500,12 @@ add_retrain = true
 # Deep Ensemble:
 mlp_large_ens = EnsembleModel(model=mlp, n=50)
 
+# LeNet-5 Ensemble:
+lenet_ens = EnsembleModel(model=lenet, n=5)
+
 add_models = Dict(
     "LeNet-5" => lenet,
+    "LeNet-5 Ensemble" => lenet_ens,
     "Large Ensemble (n=50)" => mlp_large_ens,
 )
 
@@ -528,6 +524,7 @@ _plt_order = [
     "MLP Ensemble", 
     "Large Ensemble (n=50)", 
     "LeNet-5",
+    "LeNet-5 Ensemble",
     "JEM", 
     "JEM Ensemble",
 ]
diff --git a/src/model.jl b/src/model.jl
index bf3dbe8a..a865d559 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,3 +1,4 @@
+using ChainRules: ignore_derivatives
 using ConformalPrediction
 using CounterfactualExplanations.Models
 using Flux
@@ -38,13 +39,18 @@ end
 Private function that extracts the chains from a fitted model.
 """
 function _get_chains(fitresult)
-    if fitresult isa MLJEnsembles.WrappedEnsemble
-        chains = map(res -> res[1], fitresult.ensemble)
-    elseif fitresult isa MLJBase.Signature
-        chains = [fitted_params(fitresult)[:image_classifier][1]]      # for piped image classifiers
-    else
-        chains = [fitresult[1]]
+
+    chains = []
+
+    # Collect chains outside the AD graph so Zygote does not attempt to
+    # differentiate through the fitresult indexing:
+    ignore_derivatives() do
+        if fitresult isa MLJEnsembles.WrappedEnsemble
+            _chains = map(res -> res[1], fitresult.ensemble)
+        else
+            _chains = [fitresult[1]]
+        end
+        push!(chains, _chains...)
     end
+
     return chains
 end
 
@@ -147,7 +153,9 @@ In the binary case logits are fed through the sigmoid function instead of softma
 which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
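+For a single logit z, softmax([0, z]) = (1 - σ(z), σ(z)), so applying σ recovers
+the positive-class probability directly.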
 """
 function Models.logits(M::ConformalModel, X::AbstractArray)
+
     fitresult = M.fitresult
+
     function predict_logits(fitresult, x)
         ŷ = MLUtils.stack(map(chain -> get_logits(chain,x),_get_chains(fitresult))) |> 
             y -> mean(y, dims=ndims(y)) |>
@@ -162,14 +170,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
         ŷ = ndims(ŷ) > 1 ? ŷ : permutedims([ŷ])
         return ŷ
     end
-    if ndims(X) > 2
-        yhat = map(eachslice(X, dims=ndims(X))) do x
-            predict_logits(fitresult, x)
-        end
-        yhat = MLUtils.stack(yhat)
-    else
-        yhat = predict_logits(fitresult, X)
-    end
+
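+    # Inputs are always flattened now (any reshaping happens inside the chain),
+    # so no per-slice handling is required: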
+    yhat = predict_logits(fitresult, X)
+
     return yhat
 end
 
@@ -179,6 +182,7 @@ end
 Returns the estimated probabilities for a Conformal Classifier.
 """
 function Models.probs(M::ConformalModel, X::AbstractArray)
+
     if M.likelihood == :classification_binary
         output = σ.(Models.logits(M, X))
     elseif M.likelihood == :classification_multi
-- 
GitLab