```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```

# Linearly Separable Data
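
We begin by generating synthetic, linearly separable blob data and splitting it into a training and a test set: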

```{julia}
# Hyperparameters:
_retrain = false

# Data:
test_size = 0.2
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(
    load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.)); 
    test_size=test_size
)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```

First, let's define the classifier architecture and training parameters:

```{julia}
# Model parameters:
epochs = 100
bs = min(Int(round(n_obs / 10)), 128)
n_hidden = 16
activation = Flux.swish
builder = MLJFlux.MLP(
    hidden=(n_hidden, n_hidden, n_hidden),
    σ=activation
)
n_ens = 5                              # number of models in ensemble
_loss = Flux.Losses.crossentropy       # loss function
_finaliser = Flux.softmax              # finaliser function
```
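
Next, we specify the components of the Joint Energy Model (JEM): prior distributions over inputs and outputs, a conditional sampler, and the penalty strengths: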

```{julia}
# JEM parameters:
𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
    𝒟x, 𝒟y, 
    input_size=(input_dim,), 
    batch_size=50,
)
α = [1.0, 1.0, 1e-1]    # penalty strengths
```
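
The joint energy classifier combines the MLP builder defined above with the conditional sampler: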


```{julia}
# Joint Energy Model:
model = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=bs,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α, verbosity=10,
    ),
    sampling_steps=30,
)
```
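
Next, we conformalize the classifier through simple inductive conformal prediction with a target coverage of 95% and wrap the fitted machine in an `ECCCo.ConformalModel`: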

```{julia}
conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
mach = machine(conf_model, X, labels)
@info "Begin training model."
fit!(mach)
@info "Finished training model."
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
```
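
As an optional sanity check, we can estimate the conformal predictor's empirical coverage on the held-out test split. This is a minimal sketch, assuming that `test_data` can be unpacked in the same way as the training data above and that the notebook setup brings `predict` (MLJ) and `emp_coverage` (ConformalPrediction.jl) into scope:

```{julia}
# Sanity check (sketch): empirical coverage on the held-out test split.
# Assumes `test_data` has the same structure as `counterfactual_data`.
X_test, _ = CounterfactualExplanations.DataPreprocessing.unpack_data(test_data)
X_test = table(permutedims(X_test))
y_test = test_data.output_encoder.labels
ŷ = predict(mach, X_test)
println("Empirical coverage: ", round(emp_coverage(ŷ, y_test), digits=3))
```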

```{julia}
λ₁ = 0.25
λ₂ = 0.75
λ₃ = 0.75
Λ = [λ₁, λ₂, λ₃]

opt = Flux.Optimise.Descent(0.01)
use_class_loss = false

# Benchmark generators:
generator_dict = Dict(
    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
    "ECCCo (energy delta)" => ECCCoGenerator(
        λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true
    ),
)
```
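
Finally, we choose a random factual instance from one class, set the other class as the target, and generate counterfactuals with each generator: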

```{julia}
Random.seed!(2023)

# Convert features back to a raw matrix if needed:
X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
# Pick a random instance predicted as the factual class; target the other class:
factual_label = levels(labels)[2]
x_factual = reshape(
    X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))],
    input_dim, 1,
)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]

ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity, 
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2,
        cbar=false,
    )
    if contains(name, "ECCCo")
        # Overlay the conditional samples x̂|target generated by the energy-based model:
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :], color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.5
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(800,350))
display(plt)
```