prototyping.qmd 3.16 KiB
```{julia}
# Load the shared notebook setup script, located at `notebooks/setup.jl` under the
# current working directory.
include("$(pwd())/notebooks/setup.jl")
# Evaluate `setup_notebooks` in this module.
# NOTE(review): presumably a quoted expression defined by `setup.jl` that brings the
# packages used below into scope -- confirm against that file.
eval(setup_notebooks)
```
# Linearly Separable Data
```{julia}
# Hyper:
_retrain = false

# Data:
test_size = 0.2
# Total number of samples to generate, presumably chosen so that about 1000 remain on
# the training side of the split. `round(Int, …)` is used rather than `Int(…)` so that a
# `test_size` for which the quotient is not a whole number (e.g. 0.3) does not raise an
# `InexactError`.
n_obs = round(Int, 1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(
    load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
    test_size=test_size,
)

# Unpack the training split; `X` is transposed and wrapped as a table for the `machine`
# call further down.
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels

# NOTE: `n_obs` is overwritten here with the second dimension of the training inputs, so
# from this point on it no longer holds the total requested above.
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```
First, let's define the classifier architecture and the training hyperparameters:
```{julia}
# Model parameters:
epochs = 100
# Batch size: a tenth of the training observations (rounded), capped at 128.
bs = min(round(Int, n_obs / 10), 128)
n_hidden = 16
activation = Flux.swish
# Three hidden layers of equal width. The builder reads `activation` instead of
# repeating the function name, so changing the setting above takes effect here.
builder = MLJFlux.MLP(;
    hidden=(n_hidden, n_hidden, n_hidden),
    σ=activation,
)
n_ens = 5 # number of models in ensemble
_loss = Flux.Losses.crossentropy # loss function
_finaliser = Flux.softmax # finaliser function
```
```{julia}
# JEM parameters:
# Distributions handed to the sampler: a standard normal and a uniform categorical
# with `output_dim` outcomes.
𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim)

sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=(input_dim,), batch_size=50)

# Penalty strengths, forwarded to the model as `jem_training_params.α`.
α = [1.0, 1.0, 1e-1]
```
```{julia}
# Joint Energy Model:
# Assemble the classifier from the sampler, the network builder and the training
# settings defined in the preceding chunks.
model = JointEnergyClassifier(
    sampler;
    builder,
    epochs,
    batch_size=bs,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(; α, verbosity=10),
    sampling_steps=30,
)
```
```{julia}
# Wrap the joint energy model as a conformal model and bind it to the training data.
conf_model = conformal_model(model; coverage=0.95, method=:simple_inductive)
mach = machine(conf_model, X, labels)

@info "Begin training model."
fit!(mach)
@info "Finished training model."

# Package the fitted machine's model and fit result into the ECCCo wrapper that the
# counterfactual search below operates on.
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
```
```{julia}
# Penalty weights, gathered into the vector passed to every generator as `λ`.
λ₁, λ₂, λ₃ = 0.25, 0.75, 0.75
Λ = [λ₁, λ₂, λ₃]
opt = Flux.Optimise.Descent(0.01)
use_class_loss = false

# Benchmark generators:
# Both variants share the same settings; the second additionally sets
# `use_energy_delta=true`.
generator_dict = let shared = (λ=Λ, opt=opt, use_class_loss=use_class_loss)
    Dict(
        "ECCCo" => ECCCoGenerator(; shared...),
        "ECCCo (energy delta)" => ECCCoGenerator(; shared..., use_energy_delta=true),
    )
end
```
```{julia}
Random.seed!(2023)

# Convert the table back into a `Float32` matrix with one observation per column.
# When `X` is already a matrix (e.g. the chunk is re-run) it is left untouched.
X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
factual_label = levels(labels)[2]
target = levels(labels)[1]

# Pick the factual at random among the samples the model assigns to `factual_label`.
# Fail with an explicit message instead of the opaque error `rand` raises on an empty
# collection.
candidates = findall(predict_label(M, counterfactual_data) .== factual_label)
isempty(candidates) &&
    error("No sample is predicted as $(factual_label); cannot choose a factual.")
x_factual = reshape(X[:, rand(candidates)], input_dim, 1)
factual = predict_label(M, counterfactual_data, x_factual)[1]

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce; title=name, alpha=0.2,
        cbar=false,
    )
    if contains(name, "ECCCo")
        _X = distance_from_energy(ce, return_conditionals=true)
        # Target `plt` explicitly rather than relying on it being the implicit
        # "current" plot, so the overlay cannot end up on a different figure.
        Plots.scatter!(
            plt, _X[1, :], _X[2, :];
            color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.5,
        )
    end
    push!(plts, plt)
    ces[name] = ce
end

# Combine the per-generator panels into one figure.
plt = Plots.plot(plts..., size=(800, 350))
display(plt)
```