Skip to content
Snippets Groups Projects
Commit cfb04f76 authored by Pat Alt's avatar Pat Alt
Browse files

robustnn not as good as expected unfortunately

parent 923a480b
No related branches found
No related tags found
1 merge request!4336 rebuttal
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -475,16 +475,16 @@ function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)
preproc(x) = reshape(x, (_n_in, _n_in, 1, :))
front = Flux.Chain(
Conv((k, k), 1 => c1, pad=(p, p), relu),
Conv((k, k), 1 => c1, pad=(p, p), sigmoid),
MaxPool((2, 2)),
Conv((k, k), c1 => c2, pad=(p, p), relu),
Conv((k, k), c1 => c2, pad=(p, p), sigmoid),
MaxPool((2, 2)),
Flux.flatten
)
d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
back = Flux.Chain(
Dense(d, 120, relu),
Dense(120, 84, relu),
Dense(d, 120, sigmoid),
Dense(120, 84, sigmoid),
Dense(84, n_out),
)
......@@ -537,9 +537,13 @@ add_retrain = true
# Deep Ensemble:
mlp_large_ens = EnsembleModel(model=mlp, n=50)
# CNN Ensemble:
lenet_ens = EnsembleModel(model=lenet, n=5)
add_models = Dict(
"LeNet-5" => lenet,
"RobustNet" => rob_net,
"LeNet-5 Ensemble" => lenet_ens,
# "RobustNet" => rob_net,
"Large Ensemble (n=50)" => mlp_large_ens,
)
......@@ -581,7 +585,8 @@ _plt_order = [
"MLP Ensemble",
"Large Ensemble (n=50)",
"LeNet-5",
"RobustNet",
"LeNet-5 Ensemble",
# "RobustNet",
"JEM",
"JEM Ensemble",
]
......@@ -608,10 +613,9 @@ plts = [plt_additional_models]
for (factual, target) in combos
plt, _, _ = _plot_eccco_mnist(
factual, target;
plt_order = _plt_order,
model_dict = large_model_dict,
x = factual,
target = target,
wide = true,
plot_factual = true,
img_height = 150
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment