<!-- mnist.qmd -->
```{julia}
# Load the shared notebook setup file; it is expected to define `setup_notebooks`.
# NOTE(review): `setup_notebooks` is evaluated below, so it is presumably a quoted
# expression that loads packages and defines globals used in later cells (e.g.
# `www_path`, `img_height`) — confirm in notebooks/setup.jl.
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# MNIST
```{julia}
# Data: load MNIST training images and derive the dimensions used by later cells.
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
# `counterfactual_data.X` is (features × observations); transpose so that each row is an
# observation before wrapping it as an MLJ table.
X = X |> permutedims |> table
labels = counterfactual_data.output_encoder.labels
# Re-derive `n_obs` (and the input size) from the data that was actually loaded.
input_dim, n_obs = size(counterfactual_data.X)
# Images are square; `Int` throws if `input_dim` is not a perfect square.
n_digits = Int(√input_dim)
output_dim = (length ∘ unique)(labels)
```
First, we define three image classifier architectures: a simple MLP, a Joint Energy Model (JEM) and a deep ensemble of MLPs:
```{julia}
# Model parameters:
epochs = 100
batch_size = min(round(Int, n_obs / 10), 100)   # a tenth of the data, capped at 100
n_hidden = 32
activation = Flux.relu

# Network shared by all models: Dense → BatchNorm (applying `activation`) → Dense.
# `n_in` and `n_out` are supplied by `MLJFlux.@builder`.
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden),
    BatchNorm(n_hidden, activation),
    Dense(n_hidden, n_out),
)
# Alternative builders that were tried:
#   MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation)
#   MLJFlux.MLP(hidden=(n_hidden, n_hidden, n_hidden), σ=activation)

# Passed to JEM training through `jem_training_params`.
# NOTE(review): presumably per-loss weights (classification / generative / regularisation,
# cf. the `use_gen_loss` / `use_reg_loss` switches below) — confirm against JEM docs.
α = [1.0, 1.0, 1e-2]

# Simple MLP:
mlp = NeuralNetworkClassifier(
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
)

# Joint Energy Model: uniform prior over pixel intensities and a uniform class prior
# for the conditional sampler.
𝒟x = Uniform(0, 1)
𝒟y = Categorical(fill(1 / output_dim, output_dim))
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=(input_dim,), batch_size=batch_size)
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    loss=Flux.Losses.logitcrossentropy,
    finaliser=x -> x,   # no output activation: the loss above works on raw logits
    sampling_steps=20,
    jem_training_params=(
        α=α, verbosity=10,
        # use_gen_loss=false,
        # use_reg_loss=false,
    ),
)

# Deep Ensemble of five MLPs:
mlp_ens = EnsembleModel(model=mlp, n=5)
```
```{julia}
# Wrap the JEM in a simple inductive conformal classifier with 90% target coverage.
cov = 0.9
conf_model = conformal_model(jem; coverage=cov, method=:simple_inductive)

# Fit on the training table and labels.
mach = machine(conf_model, X, labels)
fit!(mach)

# Expose the fitted machine as a model for the counterfactual search.
M = CCE.ConformalModel(mach.model, mach.fitresult)
```
```{julia}
# Visualise class-conditional samples drawn from the fitted JEM: one row per class,
# `neach` sampled images per row.
jem = mach.model.model.jem   # the JEM fitted inside the conformal wrapper
n_iter = 100                 # sampler iterations per call
_w = 1500                    # overall figure width (pixels)
neach = 10                   # samples drawn per class

# Grid size follows `output_dim` and the number of returned samples instead of a
# hard-coded 10×10, so changing `neach` keeps the layout consistent.
plts = map(1:output_dim) do i
    samples = jem.sampler(
        jem.chain, jem.sampling_rule;
        niter=n_iter, n_samples=neach, y=i,
    )
    row = [
        Plots.heatmap(
            rotl90(reshape(samples[:, j], (n_digits, n_digits))),
            axis=nothing,
            cb=false,
        )
        for j in axes(samples, 2)
    ]
    Plots.plot(row...; size=(_w, _w / output_dim), layout=(1, length(row)))
end
plt = Plots.plot(plts...; size=(_w, _w), layout=(length(plts), 1))
display(plt)
```
```{julia}
# Evaluate the fitted conformal model on the held-out MNIST test split.
test_data = load_mnist_test()
f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
println("F1 score (test): $(round(f1; digits=3))")
```
```{julia}
Random.seed!(1234)

# Set up search: draw a random training image that the model predicts as
# `factual_label` and search for a counterfactual labelled `target`.
factual_label = 9
candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
isempty(candidates) && throw(
    ArgumentError("no training observation is predicted as label $(factual_label)"),
)
x = reshape(counterfactual_data.X[:, rand(candidates)], input_dim, 1)
target = 4
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.5   # decision threshold passed to the search
T = 1     # maximum number of search iterations

# Generate counterfactual using the generic (Wachter-style) generator:
generator = GenericGenerator()
ce_wachter = generate_counterfactual(
    x, target, counterfactual_data, M, generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)

# Generate counterfactual using the CCE generator, converging on its own conditions:
generator = CCEGenerator(
    λ=[0.0, 1.0],
    temp=0.5,
    # opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
)
ce_conformal = generate_counterfactual(
    x, target, counterfactual_data, M, generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)

# Plot the factual next to each counterfactual. Image side length comes from
# `n_digits` rather than a hard-coded 28.
p1 = Plots.plot(
    convert2image(MNIST, reshape(x, n_digits, n_digits)),
    axis=nothing,
    size=(img_height, img_height),
    title="Factual",
)
plts = [p1]
ces = [ce_wachter, ce_conformal]
# Stack counterfactuals and their target probabilities along dim 3 so that slices can be
# zipped together (assumes one counterfactual per explanation — TODO confirm).
counterfactuals = reduce(
    (a, b) -> cat(a, b; dims=3),
    map(CounterfactualExplanations.counterfactual, ces),
)
phat = reduce((a, b) -> cat(a, b; dims=3), map(target_probs, ces))
for (cf, p) in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
    push!(
        plts,
        Plots.plot(
            convert2image(MNIST, reshape(cf, n_digits, n_digits)),
            axis=nothing,
            size=(img_height, img_height),
            title="p(y=$(target)|x′)=$(round(p[1]; digits=3))",
        ),
    )
end
plt = Plots.plot(
    plts...;
    size=(img_height * length(plts), img_height),
    layout=(1, length(plts)),
)
display(plt)
savefig(plt, joinpath(www_path, "cce_mnist.png"))
```
## Benchmark
```{julia}
# Benchmark generators:
# NOTE(review): `opt` and `l2_λ` are not defined in any visible cell of this notebook —
# presumably they come from notebooks/setup.jl; confirm before running this cell.
generators = Dict(
:wachter => GenericGenerator(opt=opt, λ=l2_λ),
:revise => REVISEGenerator(opt=opt, λ=l2_λ),
:greedy => GreedyGenerator(),
)
# Conformal Models:
# TODO(review): this section is empty — no conformal models are set up for the benchmark yet.
# Measures:
# Evaluation functions collected for the benchmark; this cell only defines the list
# and does not run a benchmark.
measures = [
CounterfactualExplanations.distance,
CCE.distance_from_energy,
CCE.distance_from_targets,
CounterfactualExplanations.validity,
]
```