Skip to content
Snippets Groups Projects
Commit aed8500d authored by Pat Alt's avatar Pat Alt
Browse files

issues with formatting for MLJFlux

parent 0f935942
No related branches found
No related tags found
1 merge request!4336 rebuttal
No preview for this file type
No preview for this file type
...@@ -388,6 +388,11 @@ function _plot_eccco_mnist( ...@@ -388,6 +388,11 @@ function _plot_eccco_mnist(
ces = Dict() ces = Dict()
for (mod_name, mod) in model_dict for (mod_name, mod) in model_dict
if mod.model.model isa ProbabilisticPipeline
x = mod.model.model.f(x)
_data = counterfactual_data
_data.X = mod.model.model.f(_data.X)
end
ce = generate_counterfactual( ce = generate_counterfactual(
x, target, counterfactual_data, mod, eccco_generator; x, target, counterfactual_data, mod, eccco_generator;
decision_threshold=γ, max_iter=T, decision_threshold=γ, max_iter=T,
...@@ -460,18 +465,22 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels) ...@@ -460,18 +465,22 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
p = div(k - 1, 2) p = div(k - 1, 2)
front = Flux.Chain( front = Flux.Chain(
Conv((k, k), n_channels => c1, pad=(p, p), relu), Conv((k, k), n_channels => c1, pad=(p, p), relu),
MaxPool((2, 2)), MaxPool((2, 2)),
Conv((k, k), c1 => c2, pad=(p, p), relu), Conv((k, k), c1 => c2, pad=(p, p), relu),
MaxPool((2, 2)), MaxPool((2, 2)),
Flux.flatten) Flux.flatten
)
d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
back = Flux.Chain( back = Flux.Chain(
Dense(256, 120, relu), Dense(d, 120, relu),
Dense(120, 84, relu), Dense(120, 84, relu),
Dense(84, n_out), Dense(84, n_out),
) )
return Flux.Chain(front, back)
chain = Flux.Chain(front, back)
return chain
end end
# Final model: # Final model:
...@@ -500,8 +509,8 @@ add_retrain = true ...@@ -500,8 +509,8 @@ add_retrain = true
mlp_large_ens = EnsembleModel(model=mlp, n=50) mlp_large_ens = EnsembleModel(model=mlp, n=50)
add_models = Dict( add_models = Dict(
"Large Ensemble" => mlp_large_ens,
"LeNet-5" => lenet, "LeNet-5" => lenet,
"Large Ensemble (n=50)" => mlp_large_ens,
) )
if add_retrain if add_retrain
...@@ -517,7 +526,8 @@ end ...@@ -517,7 +526,8 @@ end
_plt_order = [ _plt_order = [
"MLP", "MLP",
"MLP Ensemble", "MLP Ensemble",
"Large Ensemble", "Large Ensemble (n=50)",
"LeNet-5",
"JEM", "JEM",
"JEM Ensemble", "JEM Ensemble",
] ]
......
...@@ -40,6 +40,8 @@ Private function that extracts the chains from a fitted model. ...@@ -40,6 +40,8 @@ Private function that extracts the chains from a fitted model.
function _get_chains(fitresult) function _get_chains(fitresult)
if fitresult isa MLJEnsembles.WrappedEnsemble if fitresult isa MLJEnsembles.WrappedEnsemble
chains = map(res -> res[1], fitresult.ensemble) chains = map(res -> res[1], fitresult.ensemble)
elseif fitresult isa MLJBase.Signature
chains = [fitted_params(fitresult)[:image_classifier][1]] # for piped image classifiers
else else
chains = [fitresult[1]] chains = [fitresult[1]]
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment