Skip to content
Snippets Groups Projects
linearly_separable.qmd 9.52 KiB
```{julia}
# Load the shared notebook setup script, then evaluate the `setup_notebooks` expression
# (presumably defined by that script). NOTE(review): it presumably also loads the packages
# and defines `output_path` / `output_images_path` used below — confirm in notebooks/setup.jl.
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```

# Linearly Separable Data

```{julia}
# Hyper:
_retrain = true   # true: train and serialize the models; false: load them from `output_path`

# Data:
test_size = 0.2
# Total number of samples to generate so that roughly 1000 remain for training after the
# split. `round(Int, …)` instead of `Int(…)`: the quotient is generally not an exact
# integer (e.g. test_size = 0.3), and `Int` would then throw an `InexactError`.
n_total = round(Int, 1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(
    load_blobs(n_total; cluster_std=0.1, center_box=(-1.0 => 1.0));
    test_size=test_size,
)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
# Transpose so observations are rows before wrapping the features as an MLJ table.
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
# Dimensions of the training split (features × observations); `n_obs` is used downstream,
# e.g. to derive the batch size.
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```

First, let's create a couple of classifier architectures for this two-dimensional data:

```{julia}
# Model parameters:
epochs = 100
# Roughly 10% of the training observations per batch, capped at 128 and never below 1
# (a tiny training set would otherwise yield a batch size of 0).
batch_size = clamp(round(Int, n_obs / 10), 1, 128)
n_hidden = 16
activation = Flux.swish
builder = MLJFlux.MLP(
    hidden=(n_hidden, n_hidden, n_hidden),
    σ=activation,   # use the configured activation instead of hard-coding it a second time
)
n_ens = 5                              # number of models in ensemble
_loss = Flux.Losses.crossentropy       # loss function
_finaliser = Flux.softmax              # finaliser function
```

```{julia}
# JEM parameters:
# Priors for the conditional sampler used during JEM training: a standard normal over
# inputs and a uniform categorical distribution over the `output_dim` classes.
𝒟x = Normal()
𝒟y = Categorical(fill(1 / output_dim, output_dim))
sampler = ConditionalSampler(
    𝒟x, 𝒟y;
    batch_size=50,
    input_size=(input_dim,),
)
# Penalty strengths, forwarded to the JEM via `jem_training_params`.
α = [1.0, 1.0, 0.1]
```

```{julia}
# Simple MLP:
# Plain MLJFlux classifier built from the shared `builder` and training settings.
mlp = NeuralNetworkClassifier(
    builder=builder, 
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)

# Deep Ensemble:
# `n_ens` copies of `mlp` in an MLJ ensemble (currently not listed in `models`).
mlp_ens = EnsembleModel(model=mlp, n=n_ens)

# Joint Energy Model:
# Same architecture and training settings as `mlp`, plus the conditional `sampler`
# defined in the JEM-parameters cell; `α` is forwarded as the penalty strengths.
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α,verbosity=10,
    ),
    sampling_steps=30,  # NOTE(review): presumably SGLD steps per sampling call — confirm
)

# JEM with adversarial training:
# Currently an unmodified copy of `jem`: the flag below is commented out, and `jem_adv`
# is only referenced by the commented-out `jem_ens_plus`.
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true

# Deep Ensemble of Joint Energy Models:
# Currently not listed in `models`.
jem_ens = EnsembleModel(model=jem, n=n_ens)

# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)

# Dictionary of models:
# Only the entries listed here are trained (and subsequently evaluated/benchmarked);
# uncomment lines (and their definitions above) to include the ensemble variants.
models = Dict(
    "MLP" => mlp,
    # "MLP Ensemble" => mlp_ens,
    "JEM" => jem,
    # "JEM Ensemble" => jem_ens,
    # "JEM Ensemble+" => jem_ens_plus,
)
```

```{julia}
# Train models:
"""
    _train(model, X=X, y=labels; cov=0.95, method=:simple_inductive, mod_name="model")

Wrap `model` in a conformal predictor (conformal `method`, target `coverage=cov`), fit it
on `(X, y)` through an MLJ machine and return the fitted result as an
`ECCCo.ConformalModel`. `mod_name` only appears in the progress log messages.
"""
function _train(model, X=X, y=labels; cov=0.95, method=:simple_inductive, mod_name="model")
    mach = machine(conformal_model(model; method=method, coverage=cov), X, y)
    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."
    return ECCCo.ConformalModel(mach.model, mach.fitresult)
end

# Either load previously trained models from disk, or train every entry of `models`
# and cache the resulting dictionary for later runs.
if !_retrain
    model_dict = Serialization.deserialize(joinpath(output_path, "linearly_separable_models.jls"))
else
    model_dict = Dict(name => _train(mdl; mod_name=name) for (name, mdl) in models)
    Serialization.serialize(joinpath(output_path, "linearly_separable_models.jls"), model_dict)
end
```

```{julia}
# Evaluate models:

# Evaluation measures, keyed by the column name used in the results table.
measure = Dict(
    :f1score => multiclass_f1score,
    :acc => accuracy,
    :precision => multiclass_precision
)
# Iterating a `Dict` yields a hash-dependent order, which made the column order (and the
# row order, for `model_dict`) of the saved results arbitrary. Fix both explicitly.
measure_names = [:f1score, :acc, :precision]
measure_fns = [measure[name] for name in measure_names]

model_performance = DataFrame()
for (mod_name, model) in sort(collect(model_dict); by=first)
    # Test-set performance, one column per measure:
    scores = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=measure_fns)
    _perf = DataFrame([[s] for s in scores], measure_names)
    _perf.mod_name .= mod_name
    model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path, "linearly_separable_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "linearly_separable_model_performance.csv"), model_performance)
model_performance
```

```{julia}
# For each model, draw class-conditional samples via SGLD and plot them over the training
# data and the model's predicted probability contours for the target class.
n_regen = 1000        # passed to the sampler as `niter`
n_each = batch_size   # passed as `n_samples`: samples generated per target class
X_train = MLJFlux.reformat(X)   # training inputs, features × observations (rows 1/2 are plotted as x/y)
for (mod_name, model) in model_dict
    # Sampler set-up, rebuilt for every model as before. Names deliberately differ from the
    # globals `𝒟x`, `𝒟y`, `sampler` and `opt`: in notebooks, assigning an existing global
    # inside a top-level loop overwrites it, which changed the JEM training sampler
    # whenever the model-definition cell was re-run.
    K = length(counterfactual_data.y_levels)
    gen_input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
    gen_𝒟x = Uniform(extrema(counterfactual_data.X)...)
    gen_𝒟y = Categorical(ones(K) ./ K)
    gen_sampler = ConditionalSampler(gen_𝒟x, gen_𝒟y; input_size=gen_input_size)
    sgld_opt = ImproperSGLD()
    model_logits(x) = logits(model, x)
    model_probs(x) = probs(model, x)

    target_plts = []
    for (target_idx, target) in enumerate(levels(labels))
        X̂ = gen_sampler(model_logits, sgld_opt; niter=n_regen, n_samples=n_each, y=target_idx)
        # Axis limits cover both the training data and the generated samples.
        ex = extrema(hcat(X_train, X̂), dims=2)
        xlims = ex[1]
        ylims = ex[2]
        x1 = range(1.0f0 .* xlims...; length=10)
        x2 = range(1.0f0 .* ylims...; length=10)
        plt = Plots.contour(
            x1, x2, (a, b) -> model_probs([a, b][:, :])[target_idx];
            fill=true, alpha=0.5, title="Target: $target", cbar=true,
            xlims=xlims,
            ylims=ylims,
        )
        Plots.scatter!(
            X_train[1, :], X_train[2, :];
            color=Int.(labels.refs), group=Int.(labels.refs), alpha=0.5,
        )
        Plots.scatter!(
            X̂[1, :], X̂[2, :];
            color=fill(target_idx, size(X̂, 2)),
            group=fill(target_idx, size(X̂, 2)),
            shape=:star5, ms=10,
        )
        push!(target_plts, plt)
    end
    # One panel per target class (previously hard-coded to a 1×2 layout).
    n_panels = length(target_plts)
    plt = Plots.plot(
        target_plts...;
        layout=(1, n_panels), size=(n_panels * 500, 400), plot_title=mod_name,
    )
    # Save the combined figure once per model: saving inside the target loop overwrote
    # the same file for every target, so only the last panel was kept.
    savefig(plt, joinpath(output_images_path, "linearly_separable_generated_$(mod_name).png"))
    display(plt)
end
```

```{julia}
#| output: true
#| echo: false
#| label: fig-losses
#| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."

# Assumes `X` is still the training table from the data cell (observations as rows).
X_plot = matrix(X)
temp = 0.1   # passed as `temp` to every `contourf` call below

# Per model: smooth set size, smooth set loss, and classification loss panels.
for (mod_name, model) in model_dict
    p0 = Plots.contourf(model.model, model.fitresult, X_plot, labels; plot_set_size=true, zoom=0, temp=temp)
    p1 = Plots.contourf(model.model, model.fitresult, X_plot, labels; plot_set_loss=true, zoom=0, temp=temp)
    # Loss matrix sized to the number of classes (was hard-coded to 2×2).
    p2 = Plots.contourf(
        model.model, model.fitresult, X_plot, labels;
        plot_classification_loss=true, zoom=0, temp=temp, clim=nothing,
        loss_matrix=ones(output_dim, output_dim),
    )
    display(Plots.plot(p0, p1, p2, size=(1400, 320), plot_title=mod_name, layout=(1, 3)))
end
```

## Benchmark

```{julia}
λ₁ = 0.25
λ₂ = 0.75
λ₃ = 0.75
Λ = [λ₁, λ₂, λ₃]   # full ECCCo penalty vector

opt = Flux.Optimise.Descent(0.01)   # optimiser shared by the gradient-based generators
use_class_loss = false

# Benchmark generators. The two ablations zero out λ₂ ("no CP") and λ₃ ("no EBM")
# respectively, as their names indicate.
generator_dict = let
    eccco(λ) = ECCCoGenerator(λ=λ, opt=opt, use_class_loss=use_class_loss)
    Dict(
        "Wachter" => WachterGenerator(λ=λ₁, opt=opt),
        "REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
        "Schut" => GreedyGenerator(),
        "ECCCo" => eccco(Λ),
        "ECCCo (no CP)" => eccco([λ₁, 0.0, λ₃]),
        "ECCCo (no EBM)" => eccco([λ₁, λ₂, 0.0]),
    )
end
```

### Proof of Concept (POC)

```{julia}
Random.seed!(2023)

M = model_dict["JEM"]
# Feature matrix (features × observations) used to pick the factual. It gets its own
# name: overwriting the global `X` (the MLJ table used by `_train` and the plotting cells)
# broke re-running those cells.
X_mat = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
factual_label = levels(labels)[1]
# Candidate factuals: training points the model currently assigns to `factual_label`.
candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
isempty(candidates) &&
    error("No training sample is predicted as $(factual_label); cannot choose a factual.")
x_factual = reshape(X_mat[:, rand(candidates)], input_dim, 1)
target = levels(labels)[2]
factual = predict_label(M, counterfactual_data, x_factual)[1]

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2,
        cbar=false,
    )
    # For ECCCo variants, overlay the conditional samples returned by
    # `distance_from_energy(...; return_conditionals=true)`.
    if contains(name, "ECCCo")
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :], color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.5
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(650, 500))
display(plt)
savefig(plt, joinpath(output_images_path, "linearly_separable_poc.png"))
```

### Complete Benchmark

```{julia}
# Measures:
measures = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_targets,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
    ECCCo.set_size_penalty
]

# Benchmark every ordered (target, factual) pair of distinct classes, with the target
# as the outer loop, then stack all results into one benchmark.
classes = sort(unique(labels))
class_pairs = [(target, factual) for target in classes for factual in classes if factual != target]
bmks = map(class_pairs) do (target, factual)
    benchmark(
        counterfactual_data;
        models=model_dict,
        generators=generator_dict,
        measure=measures,
        suppress_training=true, dataname="Linearly Separable",
        n_individuals=25,
        target=target, factual=factual,
        initialization=:identity,
        converge_when=:generator_conditions,
    )
end
bmk = reduce(vcat, bmks)
CSV.write(joinpath(output_path, "linearly_separable_benchmark.csv"), bmk())
```

```{julia}
# Replace raw measure names (lowercase function names) with human-readable labels.
df = @chain bmk() begin
    @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
    @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
    @mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
    @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable))
    # The raw name is lowercase like the ones above; comparing against "Validity" never matched.
    @mutate(variable = ifelse.(variable .== "validity", "Validity", variable))
end
# One boxplot panel per (measure, model) pair, with generators on the x-axis and as colours.
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
    mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
    plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
    facet=(; linkyaxes=:none)
)
display(plt)
save(joinpath(output_images_path, "linearly_separable_benchmark.png"), plt, px_per_unit=5)
```