```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```

# MNIST

## Anecdotal Evidence

### Examples in Introduction

#### Wachter and JSMA

```{julia}
Random.seed!(2023)

# Data:
counterfactual_data = load_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
input_dim, n_obs = size(counterfactual_data.X)
M = load_mnist_mlp()

# Target:
factual_label = 9
x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 7
factual = predict_label(M, counterfactual_data, x_factual)[1]
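# Minimum predicted probability of the target class for a valid counterfactual: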
γ = 0.9

# Training params:
T = 100
```

```{julia}
# Search:
generic_generator = WachterGenerator()
ce_wachter = generate_counterfactual(
    x_factual, target, counterfactual_data, M, generic_generator; 
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
greedy_generator = GreedyGenerator(η=2.0)
ce_jsma = generate_counterfactual(
    x_factual, target, counterfactual_data, M, greedy_generator; 
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
```

```{julia}
p1 = Plots.plot(
    convert2image(MNIST, reshape(x_factual,28,28)),
    axis=([], false), 
    size=(img_height, img_height),
    title="Factual"
)
plts = [p1]

ces = [ce_wachter, ce_jsma]
counterfactuals = reduce((x, y) -> cat(x, y, dims=3), map(CounterfactualExplanations.counterfactual, ces))
phat = reduce((x, y) -> cat(x, y, dims=3), map(target_probs, ces))
for (ce, _phat, _name) in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wachter", "JSMA"])
    _title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
    plt = Plots.plot(
        convert2image(MNIST, reshape(ce, 28, 28)),
        axis=([], false), 
        size=(img_height, img_height),
        title=_title
    )
    push!(plts, plt)
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "you_may_not_like_it.png"))
```

#### REVISE

```{julia}
using CounterfactualExplanations.Models: load_mnist_vae
vae = load_mnist_vae()
vae_weak = load_mnist_vae(;strong=false)
Serialization.serialize(joinpath(output_path,"mnist_classifier.jls"), M)
Serialization.serialize(joinpath(output_path,"mnist_vae.jls"), vae)
Serialization.serialize(joinpath(output_path,"mnist_vae_weak.jls"), vae_weak)
```

```{julia}
# Define generator:
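# λ=0.0 disables the explicit distance penalty; REVISE searches in the latent space of the generative model instead.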
revise_generator = REVISEGenerator(
    opt = Flux.Optimise.Descent(0.25),
    λ=0.0,
)
# Generate recourse:
counterfactual_data.generative_model = vae # assign generative model
ce_strong = generate_counterfactual(
    x_factual, target, counterfactual_data, M, revise_generator; 
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)
counterfactual_data_weak = deepcopy(counterfactual_data)
counterfactual_data_weak.generative_model = vae_weak
ce_weak = generate_counterfactual(
    x_factual, target, counterfactual_data_weak, M, revise_generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
    converge_when=:generator_conditions,
)
```

```{julia}
ces = [ce_strong, ce_weak]
counterfactuals = reduce((x, y) -> cat(x, y, dims=3), map(CounterfactualExplanations.counterfactual, ces))
phat = reduce((x, y) -> cat(x, y, dims=3), map(target_probs, ces))
plts = [p1]
for (ce, _phat, _name) in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Strong VAE", "Weak VAE"])
    _title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
    plt = Plots.plot(
        convert2image(MNIST, reshape(ce, 28, 28)),
        axis=([], false), 
        size=(img_height, img_height),
        title=_title
    )
    push!(plts, plt)
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png"))
```

```{julia}
ces = [ce_wachter, ce_jsma, ce_strong]
counterfactuals = reduce((x, y) -> cat(x, y, dims=3), map(CounterfactualExplanations.counterfactual, ces))
phat = reduce((x, y) -> cat(x, y, dims=3), map(target_probs, ces))
plts = [p1]
for (ce, _phat, _name) in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wachter", "Schut", "REVISE"])
    _title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
    plt = Plots.plot(
        convert2image(MNIST, reshape(ce, 28, 28)),
        axis=([], false), 
        size=(img_height, img_height),
        title=_title
    )
    push!(plts, plt)
end
plt = Plots.plot(plts...; size=(0.8*panel_height*length(plts),0.8*panel_height), layout=(1,length(plts)), dpi=400)
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_motivation.png"))
```

### ECCCo

```{julia}
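# Add a small amount of Gaussian noise to the features (a common trick to stabilize training of energy-based models):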
function pre_process(x; noise::Float32=0.03f0)
    ϵ = Float32.(randn(size(x)) * noise)
    x += ϵ
    return x
end
```
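
As a quick sanity check (illustrative, not part of the original analysis), the perturbation at the default noise scale barely moves the input:

```{julia}
# Hypothetical check: with noise=0.03, the largest per-pixel shift is typically around 0.1.
x_check = Float32.(rand(input_dim))
maximum(abs.(pre_process(x_check) .- x_check))
```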

```{julia}
# Hyper:
_retrain = false
_regen = true

# Data:
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
counterfactual_data.generative_model = vae
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
n_digits = Int(sqrt(input_dim))
output_dim = length(unique(labels))
```

First, let's create a couple of image classifier architectures:

```{julia}
# Model parameters:
epochs = 10
batch_size = minimum([Int(round(n_obs/10)), 128])
n_hidden = 128
activation = Flux.swish
builder = MLJFlux.@builder Flux.Chain(
    Dense(n_in, n_hidden, activation),
    Dense(n_hidden, n_out),
)
n_ens = 5                          # number of models in ensemble
_loss = Flux.Losses.crossentropy   # loss function
_finaliser = Flux.softmax          # finaliser function
```

```{julia}
# JEM parameters:
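# Priors over inputs and labels used by the SGLD sampler: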
𝒟x = Uniform(-1.0,1.0)
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
    𝒟x, 𝒟y, 
    input_size=(input_dim,), 
    batch_size=10,
)
α = [1.0,1.0,1e-2]      # penalty strengths
```

```{julia}
# Simple MLP:
mlp = NeuralNetworkClassifier(
    builder=builder, 
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)

# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=n_ens)

# Joint Energy Model:
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α,verbosity=10,
    ),
    sampling_steps=25,
)

# JEM with adversarial training:
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true

# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)

# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)

# Dictionary of models:
models = Dict(
    "MLP" => mlp,
    "MLP Ensemble" => mlp_ens,
    "JEM" => jem,
    "JEM Ensemble" => jem_ens,
    # "JEM Ensemble+" => jem_ens_plus,
)

Serialization.serialize(joinpath(output_path,"mnist_architectures.jls"), models)
```

```{julia}
# Train models:
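# Wrap each classifier in a conformal predictor, fit it, and return it as an `ECCCo.ConformalModel`: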
function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")
    conf_model = conformal_model(model; method=method, coverage=cov)
    mach = machine(conf_model, X, y)
    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."
    M = ECCCo.ConformalModel(mach.model, mach.fitresult)
    return M
end
if _retrain
    model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in models)
    Serialization.serialize(joinpath(output_path,"mnist_models.jls"), model_dict)
else
    model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models.jls"))
end
```
```{julia}
params = DataFrame(
    Dict(
        :n_obs => Int.(round(n_obs/10)*10),
        :epochs => epochs,
        :batch_size => batch_size,
        :n_hidden => n_hidden,
        :n_layers => length(model_dict["MLP"].fitresult[1][1])-1,
        :activation => string(activation),
        :n_ens => n_ens,
        :lambda => string(α[3]),
        :jem_sampling_steps => jem.sampling_steps,
        :sgld_batch_size => sampler.batch_size,
        :dataname => "MNIST",
    )
)
CSV.write(joinpath(params_path, "mnist.csv"), params)
```

```{julia}
# Plot generated samples:
n_regen = 500
if _regen 
    for (mod_name, mod) in model_dict
        K = length(counterfactual_data.y_levels)
        input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
        𝒟x = Uniform(extrema(counterfactual_data.X)...)
        𝒟y = Categorical(ones(K) ./ K)
        sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size, prob_buffer=0.0)
        opt = ImproperSGLD()
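        # The sampler derives class-conditional energies from the model's logits: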
        f(x) = logits(mod, x)

        _w = 1500
        plts = []
        neach = 10
        for i in 1:10
            x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i)
            plts_i = []
            for j in 1:size(x, 2)
                xj = x[:,j]
                xj = reshape(xj, (n_digits, n_digits))
                push!(plts_i, Plots.heatmap(rotl90(xj), axis=nothing, cb=false))
            end
            plt = Plots.plot(plts_i..., size=(_w, 0.10*_w), layout=(1, 10))
            push!(plts, plt)
        end
        plt = Plots.plot(plts..., size=(_w,_w), layout=(10,1), plot_title=mod_name)
        savefig(plt, joinpath(output_images_path, "mnist_generated_$(mod_name).png"))
        display(plt)
    end
end
```

```{julia}
# Evaluate models:

measure = Dict(
    :f1score => multiclass_f1score, 
    :acc => accuracy, 
    :precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, mod) in model_dict
    # Test performance:
    test_data = load_mnist_test()
    _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
    _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
    _perf.mod_name .= mod_name
    _perf.dataname .= "MNIST"
    model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"mnist_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "mnist_model_performance.csv"), model_performance)
model_performance
```

### Different Models

```{julia}
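# Generate ECCCo counterfactuals for the same factual across all models and plot them side by side: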
function _plot_eccco_mnist(
    x::Union{AbstractArray, Int}=x_factual, target::Int=target;
    λ=[0.1,0.25,0.25],
    temp=0.1,
    plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
    opt = nothing,
    rng::Union{Int,AbstractRNG}=1234,
    T::Int = 100,
    use_class_loss::Bool = true,
    model_dict=model_dict,
    wide::Bool = false,
)

    # Setup:
    Random.seed!(rng)
    if x isa Int
        x = reshape(counterfactual_data.X[:,rand(findall(labels.==x))],input_dim,1)
    end

    # Generate counterfactuals using ECCCo generator:
    eccco_generator = ECCCoGenerator(
        λ=λ, 
        temp=temp, 
        opt=opt,
        use_class_loss=use_class_loss,
        nsamples=10,
        nmin=10,
    )

    ces = Dict()
    for (mod_name, mod) in model_dict
        ce = generate_counterfactual(
            x, target, counterfactual_data, mod, eccco_generator; 
            decision_threshold=γ, max_iter=T,
            initialization=:identity,
            converge_when=:generator_conditions,
        )
        ces[mod_name] = ce
    end
    _plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)

    # Plot:
    p1 = Plots.plot(
        convert2image(MNIST, reshape(x,28,28)),
        axis=nothing, 
        size=(img_height, img_height),
        title="Factual"
    )

    plts = []
    letters = collect('a':'z')[1:length(ces)]
    _count = 1

    for (_name,ce) in collect(ces)[_plt_order]
        _x = CounterfactualExplanations.counterfactual(ce)
        _phat = target_probs(ce)
        _title = "($(letters[_count]))"
        plt = Plots.plot(
            convert2image(MNIST, reshape(_x,28,28)),
            axis=([], false), 
            size=(img_height, img_height),
            title=_title
        )
        push!(plts, plt)
        _count += 1
    end
    if wide
        plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
    else
        plt = Plots.plot(plts...; size=(img_height,img_height))
    end
    
    return plt, eccco_generator, ces
end
```

```{julia}
plt, eccco_generator, ces = _plot_eccco_mnist()
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```

#### Additional Models (not in paper)

LeNet-5:

```{julia}
mutable struct LeNetBuilder
    filter_size::Int
    channels1::Int
    channels2::Int
end

function MLJFlux.build(b::LeNetBuilder, rng, n_in, n_out)

    _n_in = Int(sqrt(n_in))

    k, c1, c2 = b.filter_size, b.channels1, b.channels2

    mod(k, 2) == 1 || error("`filter_size` must be odd.")

    # Padding to preserve image size on convolution:
    p = div(k - 1, 2)

    # Reshape flat input to the WHCN format expected by `Conv`:
    preproc(x) = reshape(x, (_n_in, _n_in, 1, :))

    front = Flux.Chain(
        Conv((k, k), 1 => c1, pad=(p, p), relu),
        MaxPool((2, 2)),
        Conv((k, k), c1 => c2, pad=(p, p), relu),
        MaxPool((2, 2)),
        Flux.flatten
    )
    d = Flux.outputsize(front, (_n_in, _n_in, 1, 1)) |> first
    back = Flux.Chain(
        Dense(d, 120, relu),
        Dense(120, 84, relu),
        Dense(84, n_out),
    )

    chain = Flux.Chain(preproc, front, back)

    return chain
end

# Final model:
lenet = NeuralNetworkClassifier(
    builder=LeNetBuilder(5, 6, 16),
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)
```

```{julia}
add_retrain = true

# Deep Ensemble:
mlp_large_ens = EnsembleModel(model=mlp, n=50)

# LeNet-5 Ensemble:
lenet_ens = EnsembleModel(model=lenet, n=5)

add_models = Dict(
    "LeNet-5" => lenet,
    "LeNet-5 Ensemble" => lenet_ens,
    "Large Ensemble (n=50)" => mlp_large_ens,
)

if add_retrain
    add_model_dict = Dict(mod_name => _train(mod; mod_name=mod_name) for (mod_name, mod) in add_models)
    large_model_dict = merge(model_dict, add_model_dict)
    Serialization.serialize(joinpath(output_path,"mnist_models_large.jls"), large_model_dict)
else
    large_model_dict = Serialization.deserialize(joinpath(output_path,"mnist_models_large.jls"))
end
```

```{julia}
_plt_order = [
    "MLP", 
    "MLP Ensemble", 
    "Large Ensemble (n=50)", 
    "LeNet-5",
    "LeNet-5 Ensemble",
    "JEM", 
    "JEM Ensemble",
]
plt_additional_models, _, _ces_ = _plot_eccco_mnist(
    plt_order = _plt_order,
    model_dict=large_model_dict,
    wide = true,
)
display(plt_additional_models)
savefig(plt_additional_models, joinpath(output_images_path, "mnist_eccco_additional.png"))
```

### All Digits

```{julia}
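# Generate and plot a single counterfactual for a given factual -> target transition: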
function plot_mnist(
    factual::Int,target::Int;
    generator::AbstractGenerator,
    model::AbstractFittedModel=model_dict["JEM Ensemble"],
    data::CounterfactualData=counterfactual_data,
    rng::Union{Int,AbstractRNG}=Random.GLOBAL_RNG,
    _plot_title::Bool=true,
    show_factual::Bool=false,
    img_height::Int=180,
    kwargs...,
)
    Random.seed!(rng)

    # Extract optional keyword arguments (with defaults) so they are not passed twice below:
    kwargs = Dict{Symbol,Any}(kwargs)
    decision_threshold = pop!(kwargs, :decision_threshold, 0.9)
    max_iter = pop!(kwargs, :max_iter, 100)
    initialization = pop!(kwargs, :initialization, :identity)
    converge_when = pop!(kwargs, :converge_when, :generator_conditions)
    x = reshape(data.X[:,rand(findall(predict_label(model, data).==factual))],input_dim,1)
    ce = generate_counterfactual(
        x, target, data, model, generator; 
        decision_threshold=decision_threshold, max_iter=max_iter,
        initialization=initialization,
        converge_when=converge_when,
        kwargs...
    )

    _title = _plot_title ? "$(factual) -> $(target)" : ""

    _x = CounterfactualExplanations.counterfactual(ce)
    plt = Plots.plot(
        convert2image(MNIST, reshape(_x,28,28)),
        axis=([], false),
        size=(img_height, img_height),
        title=_title
    )
    if show_factual
        plt_factual = Plots.plot(
            convert2image(MNIST, reshape(x,28,28)),
            axis=([], false),
            size=(img_height, img_height),
            title="Factual"
        )
        plt = Plots.plot(plt_factual, plt; size=(img_height*2,img_height), layout=(1,2))
    end
    
    return plt
end
```
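
A quick usage example (illustrative, not part of the benchmark): plot a single transition with the ECCCo generator defined above.

```{julia}
# Hypothetical one-off call; any factual/target pair works:
plot_mnist(3, 5; generator=eccco_generator, show_factual=true)
```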

```{julia}
_regen_all_digits = false
if _regen_all_digits
    function plot_all_digits(rng=123;verbose=true,img_height=180,kwargs...)
        plts = []
        for i in 0:9
            for j in 0:9
                @info "Generating counterfactual for $(i) -> $(j)"
                plt = plot_mnist(i,j;kwargs...,rng=rng, img_height=img_height)
                verbose && display(plt)
                push!(plts, plt)
            end
        end
        plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10), dpi=300)
        return plt
    end
    plt = plot_all_digits(generator=eccco_generator)
    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits.png"))
end
```

## Benchmark

```{julia}
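# All benchmark generators share ECCCo's distance penalty strength and optimizer for comparability: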
Λ = eccco_generator.λ

# Benchmark generators:
generator_dict = Dict(
    "Wachter" => WachterGenerator(λ=Λ[1], opt=eccco_generator.opt),
    "REVISE" => REVISEGenerator(λ=Λ[1], opt=eccco_generator.opt),
    "Schut" => greedy_generator,
    "ECCCo" => eccco_generator,
)
```

```{julia}
generator_params = DataFrame(
    Dict(
        :λ1 => Λ[1],
        :λ2 => Λ[2],
        :λ3 => Λ[3],
        :opt => string(typeof(eccco_generator.opt)),
        :eta => eccco_generator.opt.eta,
        :dataname => "MNIST",
    )
)
CSV.write(joinpath(params_path, "generator/mnist.csv"), generator_params)
```

```{julia}
# Measures:
measures = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_targets,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
    ECCCo.set_size_penalty,
]

bmks = []
for target in sort(unique(labels))
    for factual in sort(unique(labels))
        if factual == target
            continue
        end
        bmk = benchmark(
            counterfactual_data; 
            models=model_dict, 
            generators=generator_dict, 
            measure=measures,
            suppress_training=true, dataname="MNIST",
            n_individuals=5,
            target=target, factual=factual,
            initialization=:identity,
            converge_when=:generator_conditions,
        )
        push!(bmks, bmk)
    end
end
bmk = reduce(vcat, bmks)

CSV.write(joinpath(output_path, "mnist_benchmark.csv"), bmk())
```

```{julia}
df = @chain bmk() begin
    @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
    @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
    @mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
    @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable))
    @mutate(variable = ifelse.(variable .== "Validity", "Validity", variable))
end
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) * 
    mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
    plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120), 
    facet=(; linkyaxes=:minimal)
)   
display(plt)
save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
```