From aed8500d6888add3834483a7d8a4b05cc39a865d Mon Sep 17 00:00:00 2001 From: Pat Alt <55311242+pat-alt@users.noreply.github.com> Date: Tue, 8 Aug 2023 11:40:07 +0200 Subject: [PATCH] issues with formatting for MLJFlux --- artifacts/results/mnist_vae.jls | Bin 1003682 -> 1003682 bytes artifacts/results/mnist_vae_weak.jls | Bin 105130 -> 105130 bytes notebooks/mnist.qmd | 28 ++++++++++++++++++--------- src/model.jl | 2 ++ 4 files changed, 21 insertions(+), 9 deletions(-) diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index a55474be7a30c77ab1b046ad0c2794523c514d94..872e5a43473de969e478c182fe76d324da88d0dc 100644 GIT binary patch delta 113 zcmZ3q(00*6+lCg#7N!>F7M3lnUNYMQzwqs3+|K=-@4fc+6}S0TN^CC><YHvp&h?Ft zk+Ho{hzp3hftUw~dAAn|@tL38zFCF;o%VK58P*Tl+kYwX?__MZ^k4;IHXvpPVvg;W K9-L+;8GQjkaVKX0 delta 113 zcmZ3q(00*6+lCg#7N!>F7M3lnUNYMcM{>Q_ZqIJ#0%C3;<^f{f?b+>oW+%7LN@8VV zY@eIL3dC$c%nrmH+vlcmnxEYM$&+=Z#C9(k)}4&o=OnW-GHz#<V|}l^eN!tJP~?g( J*E?-SUjWIUD^dUe diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 8ed61727f0b7c3e5775110beeca6df7f9c9d10ca..42d8d05d0113c303f09bad12ce13865f76cff0cb 100644 GIT binary patch delta 50 zcmV-20L}lZwFauS2Czzsv!jf!dXs63F_ZF&catoPcautrF_W5$lDAR^0jU_1dy03r I`yT<XdV8iA7ytkO delta 56 zcmZ3rm2K5lwhdlg(<Rs#4X3xrGnP-+Vg(XES%6$mHpWTQ7XrDPy}H&{ZNJUNxKd*K JZ+XV`RRBh56(0Zq diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index feedfb98..3fa961e0 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -388,6 +388,11 @@ function _plot_eccco_mnist( ces = Dict() for (mod_name, mod) in model_dict + if mod.model.model isa ProbabilisticPipeline + x = mod.model.model.f(x) + _data = counterfactual_data + _data.X = mod.model.model.f(_data.X) + end ce = generate_counterfactual( x, target, counterfactual_data, mod, eccco_generator; decision_threshold=γ, max_iter=T, @@ -460,18 +465,22 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels) p = div(k - 1, 2) front = Flux.Chain( - Conv((k, k), n_channels => c1, pad=(p, p), relu), - MaxPool((2, 2)), - Conv((k, k), c1 => c2, pad=(p, p), relu), - MaxPool((2, 2)), - Flux.flatten) + Conv((k, k), n_channels => c1, pad=(p, p), relu), + MaxPool((2, 2)), + Conv((k, k), c1 => c2, pad=(p, p), relu), + MaxPool((2, 2)), + Flux.flatten + ) d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first back = Flux.Chain( - Dense(256, 120, relu), + Dense(d, 120, relu), Dense(120, 84, relu), Dense(84, n_out), ) - return Flux.Chain(front, back) + + chain = Flux.Chain(front, back) + + return chain end # Final model: @@ -500,8 +509,8 @@ add_retrain = true mlp_large_ens = EnsembleModel(model=mlp, n=50) add_models = Dict( - "Large Ensemble" => mlp_large_ens, "LeNet-5" => lenet, + "Large Ensemble (n=50)" => mlp_large_ens, ) if add_retrain @@ -517,7 +526,8 @@ end _plt_order = [ "MLP", "MLP Ensemble", - "Large Ensemble", + "Large Ensemble (n=50)", + "LeNet-5", "JEM", "JEM Ensemble", ] diff --git a/src/model.jl b/src/model.jl index 57fcf48a..bf3dbe8a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -40,6 +40,8 @@ Private function that extracts the chains from a fitted model. function _get_chains(fitresult) if fitresult isa MLJEnsembles.WrappedEnsemble chains = map(res -> res[1], fitresult.ensemble) + elseif fitresult isa MLJBase.Signature + chains = [fitted_params(fitresult)[:image_classifier][1]] # for piped image classifiers else chains = [fitresult[1]] end -- GitLab