Skip to content
Snippets Groups Projects
mnist.qmd 4.67 KiB
```{julia}
# Load the shared notebook setup. `notebooks/setup.jl` must define `setup_notebooks`,
# which is evaluated at top level here.
# NOTE(review): names used below but not defined in this file (e.g. `load_mnist`,
# `img_height`, `www_path`, `CCE`) presumably come from this setup — confirm there.
include("notebooks/setup.jl")
eval(setup_notebooks)
```

# MNIST

```{julia}
# Data:
# Number of MNIST training observations to request from `load_mnist`.
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
# `counterfactual_data.X` is laid out as (input_dim × n_obs), as the `size` call below
# assumes. Transpose so that rows are observations, and wrap as a table for MLJ.
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
# Note: this overwrites `n_obs` with the number of observations actually loaded.
input_dim, n_obs = size(counterfactual_data.X)
# Side length of the (square) images; `Int` throws an InexactError if `input_dim`
# is not a perfect square.
n_digits = Int(sqrt(input_dim))
# Number of distinct classes in the training labels.
output_dim = length(unique(labels))
```

First, let's create three image classifiers: a simple MLP, a Joint Energy Model (JEM), and a deep ensemble of MLPs:

```{julia}
# Model parameters:
epochs = 100
# Batch size: 10% of the observations, capped at 100.
batch_size = min(round(Int, n_obs / 10), 100)
n_hidden = 32
activation = Flux.relu
# Dense -> BatchNorm (with the activation) -> Dense. MLJFlux substitutes `n_in` and
# `n_out` from the data when it builds the chain.
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden),
    BatchNorm(n_hidden, activation),
    Dense(n_hidden, n_out),
)
# Alternative builders (kept for experimentation):
# builder = MLJFlux.Short(n_hidden=n_hidden, dropout=0.2, σ=activation)
# builder = MLJFlux.MLP(
#     hidden=(
#         n_hidden,
#         n_hidden,
#         n_hidden,
#     ), 
#     σ=activation
# )
# Loss weights passed to JEM training via `jem_training_params.α`.
# NOTE(review): presumably one weight per loss term (classification, generative,
# regularisation; cf. `use_gen_loss`/`use_reg_loss` below) — confirm in JointEnergyModels.
α = [1.0, 1.0, 1e-2]

# Simple MLP:
mlp = NeuralNetworkClassifier(
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
)

# Joint Energy Model:
# Sampler setup: inputs start from Uniform(0, 1) and class labels are drawn uniformly
# over `output_dim` classes.
𝒟x = Uniform(0, 1)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(𝒟x, 𝒟y, input_size=(input_dim,), batch_size=batch_size)
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    batch_size=batch_size,
    # The chain outputs raw logits; `logitcrossentropy` applies the softmax itself.
    finaliser=identity,
    loss=Flux.Losses.logitcrossentropy,
    jem_training_params=(
        α=α, verbosity=10,
        # use_gen_loss=false,
        # use_reg_loss=false,
    ),
    sampling_steps=20,
    epochs=epochs,
)

# Deep Ensemble: five independently trained copies of the MLP.
mlp_ens = EnsembleModel(model=mlp, n=5)
```

```{julia}
# Target coverage for the conformal prediction sets.
cov = .90
# Wrap the JEM in a simple inductive conformal classifier.
conf_model = conformal_model(jem; method=:simple_inductive, coverage=cov)
# Fit it on the full training table and labels.
mach = machine(conf_model, X, labels)
fit!(mach)
# Wrap the fitted machine so that CounterfactualExplanations functions can use it
# (e.g. `predict_label`, `model_evaluation`, `generate_counterfactual` below).
M = CCE.ConformalModel(mach.model, mach.fitresult)
```

```{julia}
# Draw class-conditional samples from the fitted JEM and show them as a grid with
# one row per class and `neach` samples per row.
# Use a separate name so the `jem` classifier defined earlier is not overwritten;
# otherwise re-running the conformal cell would wrap the wrong object.
fitted_jem = mach.model.model.jem
n_iter = 100
_w = 1500
neach = 10
plts = map(1:output_dim) do i
    # Samples conditioned on class index `i`.
    # NOTE(review): assumes one sample per column.
    samples = fitted_jem.sampler(
        fitted_jem.chain, fitted_jem.sampling_rule;
        niter=n_iter, n_samples=neach, y=i,
    )
    row = [
        Plots.heatmap(rotl90(reshape(xj, (n_digits, n_digits))), axis=nothing, cb=false)
        for xj in eachcol(samples)
    ]
    Plots.plot(row..., size=(_w, _w / output_dim), layout=(1, length(row)))
end
plt = Plots.plot(plts..., size=(_w, _w), layout=(output_dim, 1))
display(plt)
```

```{julia}
# Evaluate the fitted conformal model `M` on the MNIST test set.
test_data = load_mnist_test()
# NOTE(review): the printout calls this an F1 score; confirm that `model_evaluation`'s
# default measure is F1 and that it returns a scalar (which `round` requires).
f1 = CounterfactualExplanations.Models.model_evaluation(M, test_data)
println("F1 score (test): $(round(f1,digits=3))")
```

```{julia}
Random.seed!(1234)

# Set up search: pick a random training image that the conformal model predicts as
# `factual_label`, then try to move it to `target`.
factual_label = 9
candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
isempty(candidates) &&
    error("No training image is predicted as label $(factual_label).")
x = reshape(counterfactual_data.X[:, rand(candidates)], input_dim, 1)
target = 4
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.5
T = 1

# Generate counterfactual using generic generator:
generator = GenericGenerator()
ce_wachter = generate_counterfactual(
    x, target, counterfactual_data, M, generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)

# Generate counterfactual using CCE generator:
generator = CCEGenerator(
    λ=[0.0, 1.0],
    temp=0.5,
    # opt=CounterfactualExplanations.Generators.JSMADescent(η=1.0),
)
ce_conformal = generate_counterfactual(
    x, target, counterfactual_data, M, generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)

# Plot the factual image, then one panel per counterfactual. Each panel is titled
# with the probability the model assigns to the target class.
p1 = Plots.plot(
    convert2image(MNIST, reshape(x, n_digits, n_digits)),
    axis=nothing,
    size=(img_height, img_height),
    title="Factual"
)
plts = [p1]

ces = [ce_wachter, ce_conformal]
# Stack along dim 3 so that each counterfactual (and its target probability) is one slice.
counterfactuals = cat(map(CounterfactualExplanations.counterfactual, ces)...; dims=3)
phat = cat(map(target_probs, ces)...; dims=3)
for (ce_img, _phat) in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
    _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))"
    panel = Plots.plot(
        convert2image(MNIST, reshape(ce_img, n_digits, n_digits)),
        axis=nothing,
        size=(img_height, img_height),
        title=_title
    )
    # `push!` mutates in place, so it works with or without soft global scope.
    push!(plts, panel)
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(www_path, "cce_mnist.png"))
```

## Benchmark

```{julia}
# Benchmark generators:
# NOTE(review): `opt` and `l2_λ` are not defined anywhere in this file. Presumably they
# come from notebooks/setup.jl; otherwise this cell throws an UndefVarError.
generators = Dict(
    :wachter => GenericGenerator(opt=opt, λ=l2_λ),
    :revise => REVISEGenerator(opt=opt, λ=l2_λ),
    :greedy => GreedyGenerator(),
)

# Conformal Models:
# TODO(review): this section is empty. No conformal models are set up for the benchmark
# yet, and nothing in this cell runs the benchmark with `generators`/`measures`.


# Measures: evaluation metrics to report for each generated counterfactual.
measures = [
    CounterfactualExplanations.distance,
    CCE.distance_from_energy,
    CCE.distance_from_targets,
    CounterfactualExplanations.validity,
]
```