Skip to content
Snippets Groups Projects
Commit 923a480b authored by Pat Alt's avatar Pat Alt
Browse files

all coding stuff done

parent b34abef1
No related branches found
No related tags found
1 merge request!4336 rebuttal
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -165,7 +165,7 @@ end
```{julia}
# Hyper:
_retrain = false
_regen = true
_regen = false
# Data:
n_obs = 10000
......@@ -369,12 +369,16 @@ function _plot_eccco_mnist(
model_dict=model_dict,
wide::Bool = false,
img_height::Int = img_height,
plot_factual::Bool = false,
kwrgs...,
)
# Setup:
Random.seed!(rng)
if x isa Int
x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1)
x_fact = counterfactual_data.X[:,rand(findall(labels.==x))][:,:]
else
x_fact = x
end
# Generate counterfactuals using ECCCo generator:
......@@ -390,7 +394,7 @@ function _plot_eccco_mnist(
ces = Dict()
for (mod_name, mod) in model_dict
ce = generate_counterfactual(
x, target, counterfactual_data, mod, eccco_generator;
x_fact, target, counterfactual_data, mod, eccco_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
......@@ -401,13 +405,17 @@ function _plot_eccco_mnist(
# Plot:
p1 = Plots.plot(
convert2image(MNIST, reshape(x,28,28)),
axis=nothing,
convert2image(MNIST, reshape(x_fact,28,28)),
axis=([], false),
size=(img_height, img_height),
title="Factual"
)
plts = []
if plot_factual
plts = [p1]
else
plts = []
end
letters = collect('a':'z')[1:length(ces)]
_count = 1
......@@ -425,9 +433,9 @@ function _plot_eccco_mnist(
_count += 1
end
if wide
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)), kwrgs...)
else
plt = Plots.plot(plts...; size=(img_height,img_height))
plt = Plots.plot(plts...; size=(img_height,img_height), kwrgs...)
end
return plt, eccco_generator, ces
......@@ -505,7 +513,7 @@ end
function MLJFlux.build(b::RobustNetBuilder, rng, n_in, n_out)
n_hidden, γ = b.n_hidden, b.lipschitz_bound
_n_hidden = fill(n_hidden,1)
_n_hidden = fill(n_hidden,2)
model_ps = DenseLBDNParams{Float32}(n_in, _n_hidden, n_out, γ; rng)
chain = Flux.Chain(DiffLBDN(model_ps))
return chain
......@@ -513,8 +521,8 @@ end
# Final model:
rob_net = NeuralNetworkClassifier(
builder=RobustNetBuilder(128, 5.0),
epochs=epochs,
builder=RobustNetBuilder(60, 5.0),
epochs=600,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
......@@ -581,11 +589,39 @@ plt_additional_models, _, _ces_ = _plot_eccco_mnist(
plt_order = _plt_order,
model_dict=large_model_dict,
wide = true,
plot_factual = true,
img_height = 150,
)
display(plt_additional_models)
savefig(plt_additional_models, joinpath(output_images_path, "mnist_eccco_additional.png"))
```
```{julia}
combos = [
(7,2),
(0,8),
(8,3),
(9,4),
]
plts = [plt_additional_models]
for (factual, target) in combos
plt, _, _ = _plot_eccco_mnist(
plt_order = _plt_order,
model_dict = large_model_dict,
x = factual,
target = target,
wide = true,
plot_factual = true,
img_height = 150
)
display(plt)
push!(plts, plt)
end
```
### All digits
```{julia}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment