diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls index a55474be7a30c77ab1b046ad0c2794523c514d94..872e5a43473de969e478c182fe76d324da88d0dc 100644 Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls index 8ed61727f0b7c3e5775110beeca6df7f9c9d10ca..42d8d05d0113c303f09bad12ce13865f76cff0cb 100644 Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd index feedfb983c656e54325ec8ae822e7ddb7bd7cc7b..3fa961e0a17f40fbc112bee4c215137f186f49b1 100644 --- a/notebooks/mnist.qmd +++ b/notebooks/mnist.qmd @@ -388,6 +388,11 @@ function _plot_eccco_mnist( ces = Dict() for (mod_name, mod) in model_dict + if mod.model.model isa ProbabilisticPipeline + x = mod.model.model.f(x) + _data = counterfactual_data + _data.X = mod.model.model.f(_data.X) + end ce = generate_counterfactual( x, target, counterfactual_data, mod, eccco_generator; decision_threshold=γ, max_iter=T, @@ -460,18 +465,22 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels) p = div(k - 1, 2) front = Flux.Chain( - Conv((k, k), n_channels => c1, pad=(p, p), relu), - MaxPool((2, 2)), - Conv((k, k), c1 => c2, pad=(p, p), relu), - MaxPool((2, 2)), - Flux.flatten) + Conv((k, k), n_channels => c1, pad=(p, p), relu), + MaxPool((2, 2)), + Conv((k, k), c1 => c2, pad=(p, p), relu), + MaxPool((2, 2)), + Flux.flatten + ) d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first back = Flux.Chain( - Dense(256, 120, relu), + Dense(d, 120, relu), Dense(120, 84, relu), Dense(84, n_out), ) - return Flux.Chain(front, back) + + chain = Flux.Chain(front, back) + + return chain end # Final model: @@ -500,8 +509,8 @@ add_retrain = true mlp_large_ens = EnsembleModel(model=mlp, n=50) add_models = Dict( - "Large Ensemble" => mlp_large_ens, "LeNet-5" => lenet, + "Large Ensemble (n=50)" => mlp_large_ens, ) if add_retrain @@ -517,7 +526,8 @@ end _plt_order = [ "MLP", "MLP Ensemble", - "Large Ensemble", + "Large Ensemble (n=50)", + "LeNet-5", "JEM", "JEM Ensemble", ] diff --git a/src/model.jl b/src/model.jl index 57fcf48a2a1e38ce81712c6a7eae3cdf5c223491..bf3dbe8a49dc235f0209f94f428f555e398493ad 100644 --- a/src/model.jl +++ b/src/model.jl @@ -40,6 +40,8 @@ Private function that extracts the chains from a fitted model. function _get_chains(fitresult) if fitresult isa MLJEnsembles.WrappedEnsemble chains = map(res -> res[1], fitresult.ensemble) + elseif fitresult isa MLJBase.Signature + chains = [fitted_params(fitresult)[:image_classifier][1]] # for piped image classifiers else chains = [fitresult[1]] end