```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# Moons Data
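In this notebook we fit a small multi-layer perceptron (MLP), a deep ensemble, a joint energy model (JEM) and a JEM ensemble to the synthetic moons data, evaluate their test performance, inspect the samples they generate and benchmark a set of counterfactual generators against them. We begin by loading and splitting the data and fixing a few global settings: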
```{julia}
# Hyperparameters:
_retrain = true
# Data:
test_size = 0.2
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_moons(n_obs; noise=0.05, factor=0.5); test_size=test_size)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```
First, let's create a couple of classifier architectures:
```{julia}
# Model parameters:
epochs = 100
batch_size = minimum([Int(round(n_obs/10)), 128])
n_hidden = 32
activation = Flux.swish
builder = MLJFlux.MLP(
    hidden=(n_hidden, n_hidden, n_hidden),
    σ=activation,
)
n_ens = 5 # number of models in ensemble
_loss = Flux.Losses.logitcrossentropy # loss function
_finaliser = x -> x # finaliser function
```
```{julia}
# JEM parameters:
𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
    𝒟x, 𝒟y,
    input_size=(input_dim,),
    batch_size=batch_size,
)
α = [1.0, 1.0, 1e-2] # penalty strengths
```
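The three entries of `α` weight the components of the joint-energy training objective. As a rough sketch (the exact composition and ordering of the terms is an assumption based on the joint energy modelling literature, not something this notebook pins down), the weighted loss has the form

$$
\mathcal{L}(\theta) = \alpha_1 \mathcal{L}_{\text{clf}}(\theta) + \alpha_2 \mathcal{L}_{\text{gen}}(\theta) + \alpha_3 \mathcal{L}_{\text{reg}}(\theta),
$$

where $\mathcal{L}_{\text{clf}}$ is the classification loss, $\mathcal{L}_{\text{gen}}$ the generative (energy-based) loss estimated from SGLD samples, and $\mathcal{L}_{\text{reg}}$ a regularisation penalty on the energy magnitudes, which receives the much smaller weight of $10^{-2}$ here.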
```{julia}
# Simple MLP:
mlp = NeuralNetworkClassifier(
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
)
# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=n_ens)
# Joint Energy Model:
jem = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α, verbosity=10,
    ),
    sampling_steps=20,
)
# JEM with adversarial training:
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true
# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)
# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)
# Dictionary of models:
models = Dict(
    "MLP" => mlp,
    "MLP Ensemble" => mlp_ens,
    "JEM" => jem,
    "JEM Ensemble" => jem_ens,
    # "JEM Ensemble+" => jem_ens_plus,
)
```
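Each model is wrapped in a conformal predictor and trained below. If `_retrain` is `false`, previously trained models are deserialised from disk instead: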
```{julia}
# Train models:
function _train(model, X=X, y=labels; cov=0.95, method=:simple_inductive, mod_name="model")
    conf_model = conformal_model(model; method=method, coverage=cov)
    mach = machine(conf_model, X, y)
    @info "Begin training $mod_name."
    fit!(mach)
    @info "Finished training $mod_name."
    M = ECCCo.ConformalModel(mach.model, mach.fitresult)
    return M
end
if _retrain
    model_dict = Dict(mod_name => _train(model; mod_name=mod_name) for (mod_name, model) in models)
    Serialization.serialize(joinpath(output_path, "moons_models.jls"), model_dict)
else
    model_dict = Serialization.deserialize(joinpath(output_path, "moons_models.jls"))
end
```
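Each trained machine is wrapped as an `ECCCo.ConformalModel`, which exposes the same `logits`/`probs` interface used for plotting and benchmarking below. As an optional sanity check (a sketch rather than part of the original pipeline; it assumes the `"MLP"` key and the expected output shape), class probabilities can be queried directly on the raw feature matrix:

```julia
# Optional sanity check (sketch): query one of the trained conformal wrappers on the
# raw training inputs via the same `probs` interface used in the plotting cell below.
M = model_dict["MLP"]
p̂ = probs(M, counterfactual_data.X)  # assumed shape: output_dim × n_obs
size(p̂)
```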
```{julia}
# Evaluate models:
measure = Dict(
    :f1score => multiclass_f1score,
    :acc => accuracy,
    :precision => multiclass_precision,
)
model_performance = DataFrame()
for (mod_name, model) in model_dict
    # Test performance:
    _perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
    _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
    _perf.mod_name .= mod_name
    model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"moons_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "moons_model_performance.csv"), model_performance)
model_performance
```
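The next cell draws class-conditional samples from each trained model using a `ConditionalSampler` driven by `ImproperSGLD`, and overlays them on the observed data together with the predicted probability surface for each class. As a hedged sketch (the exact step-size and noise schedule implemented by `ImproperSGLD` is not shown here and is an assumption), Langevin-style sampling iterates updates of the form

$$
\mathbf{x}_{t+1} = \mathbf{x}_t + \frac{\epsilon^2}{2} \nabla_{\mathbf{x}} \log p_\theta(\mathbf{x}_t \mid y) + \epsilon \mathbf{r}_t, \qquad \mathbf{r}_t \sim \mathcal{N}(\mathbf{0}, \mathbf{I}),
$$

where the conditional density is induced by the classifier logits passed in through `f`.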
```{julia}
n_regen = 200
n_each = batch_size
for (mod_name, model) in model_dict
    K = length(counterfactual_data.y_levels)
    input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
    𝒟x = Uniform(extrema(counterfactual_data.X)...)
    𝒟y = Categorical(ones(K) ./ K)
    sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
    opt = ImproperSGLD()
    plts = []
    for target in levels(labels)
        target_idx = findall(levels(labels) .== target)[1]
        f(x) = logits(model, x)
        X̂ = sampler(f, opt; niter=n_regen, n_samples=n_each, y=target_idx)
        ex = extrema(hcat(MLJFlux.reformat(X), X̂), dims=2)
        xlims = ex[1]
        ylims = ex[2]
        x1 = range(1.0f0 .* xlims..., length=100)
        x2 = range(1.0f0 .* ylims..., length=100)
        p(x) = probs(model, x)
        plt = Plots.contour(
            x1, x2, (x, y) -> p([x, y][:, :])[target_idx],
            fill=true, alpha=0.5, title="Target: $target", cbar=true,
            xlims=xlims,
            ylims=ylims,
        )
        Plots.scatter!(
            MLJFlux.reformat(X)[1, :], MLJFlux.reformat(X)[2, :],
            color=Int.(labels.refs) .- 1, group=Int.(labels.refs) .- 1, alpha=0.5,
        )
        Plots.scatter!(
            X̂[1, :], X̂[2, :],
            color=repeat([target], size(X̂, 2)),
            group=repeat([target], size(X̂, 2)),
            shape=:star5, ms=10,
        )
        push!(plts, plt)
    end
    plt = Plots.plot(plts..., layout=(1, 2), size=(2 * 500, 400), plot_title=mod_name)
    # Save the combined per-model figure once, rather than overwriting it for each target:
    savefig(plt, joinpath(output_images_path, "moons_generated_$(mod_name).png"))
    display(plt)
end
```
## Benchmark
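We benchmark four counterfactual generators (Wachter, REVISE, Greedy and ECCCo) against every trained model, recording the counterfactual cost (distance), validity and redundancy alongside the ECCCo-specific distances from energy-based samples and from observations in the target class: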
```{julia}
# Benchmark generators:
generator_dict = Dict(
    :wachter => WachterGenerator(),
    :revise => REVISEGenerator(),
    :greedy => GreedyGenerator(),
    :eccco => ECCCoGenerator(),
)
# Measures:
measures = [
    CounterfactualExplanations.distance,
    ECCCo.distance_from_energy,
    ECCCo.distance_from_targets,
    CounterfactualExplanations.Evaluation.validity,
    CounterfactualExplanations.Evaluation.redundancy,
]
bmk = benchmark(
    counterfactual_data;
    models=model_dict,
    generators=generator_dict,
    measure=measures,
    suppress_training=true, dataname="Moons",
    n_individuals=5,
    target=0, factual=1,
    initialization=:identity,
)
CSV.write(joinpath(output_path, "moons_benchmark.csv"), bmk())
```
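The summary below reports the mean and standard deviation of each measure by generator and model, filtered to the energy-based distance (the non-conformity measure):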
```{julia}
@chain bmk() begin
    @group_by(dataname, generator, model, variable)
    @summarize(mean=mean(value), sd=std(value))
    @ungroup
    @filter(variable == "distance_from_energy")
end
```
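Finally, the measures are relabelled for presentation and their distributions are plotted per generator, faceted by measure and model: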
```{julia}
df = @chain bmk() begin
    @mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
    @mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
    @mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
    @mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable))
    @mutate(variable = ifelse.(variable .== "validity", "Validity", variable))
end
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
    mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
    plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
    facet=(; linkyaxes=:none)
)
display(plt)
save(joinpath(output_images_path, "moons_benchmark.png"), plt, px_per_unit=5)
```