Skip to content
Snippets Groups Projects
Commit aed8500d authored by Pat Alt's avatar Pat Alt
Browse files

issues with formatting for MLJFlux

parent 0f935942
No related branches found
No related tags found
1 merge request!4336 rebuttal
No preview for this file type
No preview for this file type
......@@ -388,6 +388,11 @@ function _plot_eccco_mnist(
ces = Dict()
for (mod_name, mod) in model_dict
if mod.model.model isa ProbabilisticPipeline
x = mod.model.model.f(x)
_data = counterfactual_data
_data.X = mod.model.model.f(_data.X)
end
ce = generate_counterfactual(
x, target, counterfactual_data, mod, eccco_generator;
decision_threshold=γ, max_iter=T,
......@@ -460,18 +465,22 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out, n_channels)
p = div(k - 1, 2)
front = Flux.Chain(
Conv((k, k), n_channels => c1, pad=(p, p), relu),
MaxPool((2, 2)),
Conv((k, k), c1 => c2, pad=(p, p), relu),
MaxPool((2, 2)),
Flux.flatten)
Conv((k, k), n_channels => c1, pad=(p, p), relu),
MaxPool((2, 2)),
Conv((k, k), c1 => c2, pad=(p, p), relu),
MaxPool((2, 2)),
Flux.flatten
)
d = Flux.outputsize(front, (n_in..., n_channels, 1)) |> first
back = Flux.Chain(
Dense(256, 120, relu),
Dense(d, 120, relu),
Dense(120, 84, relu),
Dense(84, n_out),
)
return Flux.Chain(front, back)
chain = Flux.Chain(front, back)
return chain
end
# Final model:
......@@ -500,8 +509,8 @@ add_retrain = true
mlp_large_ens = EnsembleModel(model=mlp, n=50)
add_models = Dict(
"Large Ensemble" => mlp_large_ens,
"LeNet-5" => lenet,
"Large Ensemble (n=50)" => mlp_large_ens,
)
if add_retrain
......@@ -517,7 +526,8 @@ end
_plt_order = [
"MLP",
"MLP Ensemble",
"Large Ensemble",
"Large Ensemble (n=50)",
"LeNet-5",
"JEM",
"JEM Ensemble",
]
......
......@@ -40,6 +40,8 @@ Private function that extracts the chains from a fitted model.
function _get_chains(fitresult)
if fitresult isa MLJEnsembles.WrappedEnsemble
chains = map(res -> res[1], fitresult.ensemble)
elseif fitresult isa MLJBase.Signature
chains = [fitted_params(fitresult)[:image_classifier][1]] # for piped image classifiers
else
chains = [fitresult[1]]
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment