Skip to content
Snippets Groups Projects
Commit a933cc9b authored by Pat Alt's avatar Pat Alt
Browse files

LeNet implemented

parent aed8500d
No related branches found
No related tags found
1 merge request!4336 rebuttal
No preview for this file type
No preview for this file type
...@@ -388,11 +388,6 @@ function _plot_eccco_mnist( ...@@ -388,11 +388,6 @@ function _plot_eccco_mnist(
ces = Dict() ces = Dict()
for (mod_name, mod) in model_dict for (mod_name, mod) in model_dict
if mod.model.model isa ProbabilisticPipeline
x = mod.model.model.f(x)
_data = counterfactual_data
_data.X = mod.model.model.f(_data.X)
end
ce = generate_counterfactual( ce = generate_counterfactual(
x, target, counterfactual_data, mod, eccco_generator; x, target, counterfactual_data, mod, eccco_generator;
decision_threshold=γ, max_iter=T, decision_threshold=γ, max_iter=T,
...@@ -455,7 +450,11 @@ mutable struct LeNetBuilder ...@@ -455,7 +450,11 @@ mutable struct LeNetBuilder
channels2::Int channels2::Int
end end
function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels) preproc(X) = reshape(X, (28, 28, 1, :))
function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
_n_in = Int(sqrt(n_in))
k, c1, c2 = b.filter_size, b.channels1, b.channels2 k, c1, c2 = b.filter_size, b.channels1, b.channels2
...@@ -464,42 +463,35 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels) ...@@ -464,42 +463,35 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
# padding to preserve image size on convolution: # padding to preserve image size on convolution:
p = div(k - 1, 2) p = div(k - 1, 2)
preproc(x) = reshape(x, (_n_in, _n_in, 1, :))
front = Flux.Chain( front = Flux.Chain(
Conv((k, k), n_channels => c1, pad=(p, p), relu), Conv((k, k), 1 => c1, pad=(p, p), relu),
MaxPool((2, 2)), MaxPool((2, 2)),
Conv((k, k), c1 => c2, pad=(p, p), relu), Conv((k, k), c1 => c2, pad=(p, p), relu),
MaxPool((2, 2)), MaxPool((2, 2)),
Flux.flatten Flux.flatten
) )
d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
back = Flux.Chain( back = Flux.Chain(
Dense(d, 120, relu), Dense(d, 120, relu),
Dense(120, 84, relu), Dense(120, 84, relu),
Dense(84, n_out), Dense(84, n_out),
) )
chain = Flux.Chain(front, back) chain = Flux.Chain(preproc, front, back)
return chain return chain
end end
# Final model: # Final model:
lenet = ImageClassifier( lenet = NeuralNetworkClassifier(
builder=LeNetBuilder(5, 6, 16), builder=LeNetBuilder(5, 6, 16),
epochs=epochs, epochs=epochs,
batch_size=batch_size, batch_size=batch_size,
finaliser=_finaliser, finaliser=_finaliser,
loss=_loss, loss=_loss,
) )
# Convert to image:
function toimg(X)
X = matrix(X) |> x -> reshape(x, 28, 28, :)
X = coerce(X, GrayImage)
return X
end
lenet = (X -> toimg(X)) |> lenet
``` ```
```{julia} ```{julia}
...@@ -508,8 +500,12 @@ add_retrain = true ...@@ -508,8 +500,12 @@ add_retrain = true
# Deep Ensemble: # Deep Ensemble:
mlp_large_ens = EnsembleModel(model=mlp, n=50) mlp_large_ens = EnsembleModel(model=mlp, n=50)
# LeNet-5 Ensemble:
lenet_ens = EnsembleModel(model=lenet, n=5)
add_models = Dict( add_models = Dict(
"LeNet-5" => lenet, "LeNet-5" => lenet,
"LeNet-5 Ensemble" => lenet_ens,
"Large Ensemble (n=50)" => mlp_large_ens, "Large Ensemble (n=50)" => mlp_large_ens,
) )
...@@ -528,6 +524,7 @@ _plt_order = [ ...@@ -528,6 +524,7 @@ _plt_order = [
"MLP Ensemble", "MLP Ensemble",
"Large Ensemble (n=50)", "Large Ensemble (n=50)",
"LeNet-5", "LeNet-5",
"LeNet-5 Ensemble",
"JEM", "JEM",
"JEM Ensemble", "JEM Ensemble",
] ]
......
using ChainRules: ignore_derivatives
using ConformalPrediction using ConformalPrediction
using CounterfactualExplanations.Models using CounterfactualExplanations.Models
using Flux using Flux
...@@ -38,13 +39,18 @@ end ...@@ -38,13 +39,18 @@ end
Private function that extracts the chains from a fitted model. Private function that extracts the chains from a fitted model.
""" """
function _get_chains(fitresult) function _get_chains(fitresult)
if fitresult isa MLJEnsembles.WrappedEnsemble
chains = map(res -> res[1], fitresult.ensemble) chains = []
elseif fitresult isa MLJBase.Signature
chains = [fitted_params(fitresult)[:image_classifier][1]] # for piped image classifiers ignore_derivatives() do
else if fitresult isa MLJEnsembles.WrappedEnsemble
chains = [fitresult[1]] _chains = map(res -> res[1], fitresult.ensemble)
else
_chains = [fitresult[1]]
end
push!(chains, _chains...)
end end
return chains return chains
end end
...@@ -147,7 +153,9 @@ In the binary case logits are fed through the sigmoid function instead of softma ...@@ -147,7 +153,9 @@ In the binary case logits are fed through the sigmoid function instead of softma
which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
""" """
function Models.logits(M::ConformalModel, X::AbstractArray) function Models.logits(M::ConformalModel, X::AbstractArray)
fitresult = M.fitresult fitresult = M.fitresult
function predict_logits(fitresult, x) function predict_logits(fitresult, x)
= MLUtils.stack(map(chain -> get_logits(chain,x),_get_chains(fitresult))) |> = MLUtils.stack(map(chain -> get_logits(chain,x),_get_chains(fitresult))) |>
y -> mean(y, dims=ndims(y)) |> y -> mean(y, dims=ndims(y)) |>
...@@ -162,14 +170,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray) ...@@ -162,14 +170,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
= ndims() > 1 ? : permutedims([]) = ndims() > 1 ? : permutedims([])
return return
end end
if ndims(X) > 2
yhat = map(eachslice(X, dims=ndims(X))) do x yhat = predict_logits(fitresult, X)
predict_logits(fitresult, x)
end
yhat = MLUtils.stack(yhat)
else
yhat = predict_logits(fitresult, X)
end
return yhat return yhat
end end
...@@ -179,6 +182,7 @@ end ...@@ -179,6 +182,7 @@ end
Returns the estimated probabilities for a Conformal Classifier. Returns the estimated probabilities for a Conformal Classifier.
""" """
function Models.probs(M::ConformalModel, X::AbstractArray) function Models.probs(M::ConformalModel, X::AbstractArray)
if M.likelihood == :classification_binary if M.likelihood == :classification_binary
output = σ.(Models.logits(M, X)) output = σ.(Models.logits(M, X))
elseif M.likelihood == :classification_multi elseif M.likelihood == :classification_multi
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment