From aed8500d6888add3834483a7d8a4b05cc39a865d Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Tue, 8 Aug 2023 11:40:07 +0200
Subject: [PATCH] issues with formatting for MLJFlux

---
 artifacts/results/mnist_vae.jls      | Bin 1003682 -> 1003682 bytes
 artifacts/results/mnist_vae_weak.jls | Bin 105130 -> 105130 bytes
 notebooks/mnist.qmd                  |  28 ++++++++++++++++++---------
 src/model.jl                         |   2 ++
 4 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index a55474be7a30c77ab1b046ad0c2794523c514d94..872e5a43473de969e478c182fe76d324da88d0dc 100644
GIT binary patch
delta 113
zcmZ3q(00*6+lCg#7N!>F7M3lnUNYMQzwqs3+|K=-@4fc+6}S0TN^CC><YHvp&h?Ft
zk+Ho{hzp3hftUw~dAAn|@tL38zFCF;o%VK58P*Tl+kYwX?__MZ^k4;IHXvpPVvg;W
K9-L+;8GQjkaVKX0

delta 113
zcmZ3q(00*6+lCg#7N!>F7M3lnUNYMcM{>Q_ZqIJ#0%C3;<^f{f?b+>oW+%7LN@8VV
zY@eIL3dC$c%nrmH+vlcmnxEYM$&+=Z#C9(k)}4&o=OnW-GHz#<V|}l^eN!tJP~?g(
J*E?-SUjWIUD^dUe

diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index 8ed61727f0b7c3e5775110beeca6df7f9c9d10ca..42d8d05d0113c303f09bad12ce13865f76cff0cb 100644
GIT binary patch
delta 50
zcmV-20L}lZwFauS2Czzsv!jf!dXs63F_ZF&catoPcautrF_W5$lDAR^0jU_1dy03r
I`yT<XdV8iA7ytkO

delta 56
zcmZ3rm2K5lwhdlg(<Rs#4X3xrGnP-+Vg(XES%6$mHpWTQ7XrDPy}H&{ZNJUNxKd*K
JZ+XV`RRBh56(0Zq

diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index feedfb98..3fa961e0 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -388,6 +388,11 @@ function _plot_eccco_mnist(
 
     ces = Dict()
     for (mod_name, mod) in model_dict
+        if mod.model.model isa ProbabilisticPipeline
+            x = mod.model.model.f(x)
+            _data = counterfactual_data
+            _data.X = mod.model.model.f(_data.X)
+        end
         ce = generate_counterfactual(
             x, target, counterfactual_data, mod, eccco_generator; 
             decision_threshold=γ, max_iter=T,
@@ -460,18 +465,22 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
 	p = div(k - 1, 2)
 
 	front = Flux.Chain(
-			   Conv((k, k), n_channels => c1, pad=(p, p), relu),
-			   MaxPool((2, 2)),
-			   Conv((k, k), c1 => c2, pad=(p, p), relu),
-			   MaxPool((2, 2)),
-			   Flux.flatten)
+        Conv((k, k), n_channels => c1, pad=(p, p), relu),
+        MaxPool((2, 2)),
+        Conv((k, k), c1 => c2, pad=(p, p), relu),
+        MaxPool((2, 2)),
+        Flux.flatten
+    )
 	d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
     back = Flux.Chain(
-        Dense(256, 120, relu),
+        Dense(d, 120, relu),
         Dense(120, 84, relu),
         Dense(84, n_out),
     )
-	return Flux.Chain(front, back)
+
+    chain = Flux.Chain(front, back)
+
+	return chain
 end
 
 # Final model:
@@ -500,8 +509,8 @@ add_retrain = true
 mlp_large_ens = EnsembleModel(model=mlp, n=50)
 
 add_models = Dict(
-    "Large Ensemble" => mlp_large_ens,
     "LeNet-5" => lenet,
+    "Large Ensemble (n=50)" => mlp_large_ens,
 )
 
 if add_retrain
@@ -517,7 +526,8 @@ end
 _plt_order = [
     "MLP", 
     "MLP Ensemble", 
-    "Large Ensemble", 
+    "Large Ensemble (n=50)", 
+    "LeNet-5",
     "JEM", 
     "JEM Ensemble",
 ]
diff --git a/src/model.jl b/src/model.jl
index 57fcf48a..bf3dbe8a 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -40,6 +40,8 @@ Private function that extracts the chains from a fitted model.
 function _get_chains(fitresult)
     if fitresult isa MLJEnsembles.WrappedEnsemble
         chains = map(res -> res[1], fitresult.ensemble)
+    elseif fitresult isa MLJBase.Signature
+        chains = [fitted_params(fitresult)[:image_classifier][1]]      # for piped image classifiers
     else
         chains = [fitresult[1]]
     end
-- 
GitLab