Skip to content
Snippets Groups Projects
Commit 860ded9e authored by Pat Alt's avatar Pat Alt
Browse files

added lin sep

parent e70f6a9e
No related branches found
No related tags found
No related merge requests found
......@@ -180,17 +180,6 @@ bmk = benchmark(
CSV.write(joinpath(output_path, "cal_housing_benchmark.csv"), bmk())
```
```{julia}
@chain bmk() begin
@group_by(dataname, generator, model, variable)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable == "distance_from_energy")
end
```
```{julia}
df = @chain bmk() begin
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
......
......@@ -196,14 +196,14 @@ generator_dict = Dict(
"REVISE" => REVISEGenerator(λ=λ₁),
"Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ),
"ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃]),
"ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0]),
)
```
### POC
```{julia}
Random.seed!(2023)
M = model_dict["JEM"]
X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
factual_label = levels(labels)[1]
......@@ -248,31 +248,30 @@ measures = [
ECCCo.set_size_penalty
]
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="Circles",
n_individuals=5,
target=0, factual=1,
initialization=:identity,
converge_when=:generator_conditions,
)
CSV.write(joinpath(output_path, "circles_benchmark.csv"), bmk())
```
```{julia}
@chain bmk() begin
@group_by(dataname, generator, model, variable)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable == "distance_from_energy")
bmks = []
for target in sort(unique(labels))
for factual in sort(unique(labels))
if factual == target
continue
end
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="Circles",
n_individuals=5,
target=target, factual=factual,
initialization=:identity,
converge_when=:generator_conditions,
)
push!(bmks, bmk)
end
end
bmk = reduce(vcat, bmks)
CSV.write(joinpath(output_path, "circles_benchmark.csv"), bmk())
```
```{julia}
df = @chain bmk() begin
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
......
```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# Circles Data
```{julia}
# Hyper:
_retrain = true
# Data:
test_size = 0.2
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_linearly_separable(n_obs); test_size=test_size)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```
First, let's create a couple of image classifier architectures:
```{julia}
# Model parameters:
epochs = 100
batch_size = minimum([Int(round(n_obs/10)), 128])
n_hidden = 16
activation = Flux.swish
builder = MLJFlux.MLP(
hidden=(n_hidden, n_hidden, n_hidden,),
σ=activation
)
n_ens = 5 # number of models in ensemble
_loss = Flux.Losses.logitcrossentropy # loss function
_finaliser = x -> x # finaliser function
```
```{julia}
# JEM parameters:
𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
𝒟x, 𝒟y,
input_size=(input_dim,),
batch_size=batch_size,
)
α = [1.0,1.0,1e-2] # penalty strengths
```
```{julia}
# Simple MLP:
mlp = NeuralNetworkClassifier(
builder=builder,
epochs=epochs,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
)
# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=n_ens)
# Joint Energy Model:
jem = JointEnergyClassifier(
sampler;
builder=builder,
epochs=epochs,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
jem_training_params=(
α=α,verbosity=10,
),
sampling_steps=20,
)
# JEM with adversarial training:
jem_adv = deepcopy(jem)
# jem_adv.adv_training = true
# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)
# Deep Ensemble of Joint Energy Models with adversarial training:
# jem_ens_plus = EnsembleModel(model=jem_adv, n=n_ens)
# Dictionary of models:
models = Dict(
"MLP" => mlp,
# "MLP Ensemble" => mlp_ens,
"JEM" => jem,
# "JEM Ensemble" => jem_ens,
# "JEM Ensemble+" => jem_ens_plus,
)
```
```{julia}
# Train models:
function _train(model, X=X, y=labels; cov=.95, method=:simple_inductive, mod_name="model")
conf_model = conformal_model(model; method=method, coverage=cov)
mach = machine(conf_model, X, y)
@info "Begin training $mod_name."
fit!(mach)
@info "Finished training $mod_name."
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
return M
end
if _retrain
model_dict = Dict(mod_name => _train(model; mod_name=mod_name) for (mod_name, model) in models)
Serialization.serialize(joinpath(output_path,"circles_models.jls"), model_dict)
else
model_dict = Serialization.deserialize(joinpath(output_path,"circles_models.jls"))
end
```
```{julia}
# Evaluate models:
measure = Dict(
:f1score => multiclass_f1score,
:acc => accuracy,
:precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, model) in model_dict
# Test performance:
_perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path,"circles_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "circles_model_performance.csv"), model_performance)
model_performance
```
```{julia}
n_regen = 200
n_each = batch_size
for (mod_name, model) in model_dict
K = length(counterfactual_data.y_levels)
input_size = size(selectdim(counterfactual_data.X, ndims(counterfactual_data.X), 1))
𝒟x = Uniform(extrema(counterfactual_data.X)...)
𝒟y = Categorical(ones(K) ./ K)
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=input_size)
opt = ImproperSGLD()
plts = []
for target in levels(labels)
target_idx = findall(levels(labels) .== target)[1]
f(x) = logits(model, x)
X̂ = sampler(f, opt; niter=n_regen, n_samples=n_each, y=target_idx)
ex = extrema(hcat(MLJFlux.reformat(X),X̂), dims=2)
xlims = ex[1]
ylims = ex[2]
x1 = range(1.0f0.*xlims...,length=100)
x2 = range(1.0f0.*ylims...,length=100)
p(x) = probs(model, x)
plt = Plots.contour(
x1, x2, (x, y) -> p([x, y][:,:])[target_idx],
fill=true, alpha=0.5, title="Target: $target", cbar=true,
xlims=xlims,
ylims=ylims,
)
Plots.scatter!(
MLJFlux.reformat(X)[1,:], MLJFlux.reformat(X)[2,:],
color=Int.(labels.refs).-1, group=Int.(labels.refs).-1, alpha=0.5
)
Plots.scatter!(
X̂[1,:], X̂[2,:],
color=repeat([target], size(X̂,2)),
group=repeat([target], size(X̂,2)),
shape=:star5, ms=10
)
savefig(plt, joinpath(output_images_path, "circles_generated_$(mod_name).png"))
push!(plts, plt)
end
plt = Plots.plot(plts..., layout=(1, 2), size=(2*500, 400), plot_title=mod_name)
display(plt)
end
```
## Benchmark
```{julia}
λ₁ = 0.25
λ₂ = 0.75
λ₃ = 0.75
Λ = [λ₁, λ₂, λ₃]
# Benchmark generators:
generator_dict = Dict(
"Wachter" => WachterGenerator(λ=λ₁),
"REVISE" => REVISEGenerator(λ=λ₁),
"Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ),
"ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃]),
"ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0]),
)
```
### POC
```{julia}
M = model_dict["JEM"]
X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
factual_label = levels(labels)[1]
x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = levels(labels)[2]
factual = predict_label(M, counterfactual_data, x_factual)[1]
ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
ce = generate_counterfactual(
x_factual, target, counterfactual_data, M, generator;
initialization=:identity,
converge_when=:generator_conditions,
)
plt = Plots.plot(ce, title=name, alpha=0.2, cbar=false, axis=nothing)
if name == "ECCCo"
_X = distance_from_energy(ce, return_conditionals=true)
Plots.scatter!(
_X[1,:],_X[2,:], color=:purple, shape=:star5,
ms=10, label="x̂|$target", alpha=0.5
)
end
push!(plts, plt)
ces[name] = ce
end
plt = Plots.plot(plts..., size=(500,520))
display(plt)
savefig(plt, joinpath(output_images_path, "circles_poc.png"))
```
### Complete Benchmark
```{julia}
# Measures:
measures = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.Evaluation.validity,
CounterfactualExplanations.Evaluation.redundancy,
ECCCo.set_size_penalty
]
bmks = []
for target in sort(unique(labels))
for factual in sort(unique(labels))
if factual == target
continue
end
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="Circles",
n_individuals=5,
target=target, factual=factual,
initialization=:identity,
converge_when=:generator_conditions,
)
push!(bmks, bmk)
end
end
bmk = reduce(vcat, bmks)
CSV.write(joinpath(output_path, "circles_benchmark.csv"), bmk())
```
```{julia}
df = @chain bmk() begin
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
@mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
@mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
@mutate(variable = ifelse.(variable .== "redundancy", "Redundancy", variable))
@mutate(variable = ifelse.(variable .== "Validity", "Validity", variable))
end
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
facet=(; linkyaxes=:none)
)
display(plt)
save(joinpath(output_images_path, "circles_benchmark.png"), plt, px_per_unit=5)
```
\ No newline at end of file
......@@ -196,6 +196,8 @@ generator_dict = Dict(
"REVISE" => REVISEGenerator(λ=λ₁),
"Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ),
"ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃]),
"ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0]),
)
```
......@@ -250,30 +252,30 @@ measures = [
ECCCo.set_size_penalty
]
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="Moons",
n_individuals=5,
target=1, factual=0,
initialization=:identity,
converge_when=:generator_conditions,
)
CSV.write(joinpath(output_path, "moons_benchmark.csv"), bmk())
```
```{julia}
@chain bmk() begin
@group_by(dataname, generator, model, variable)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable == "distance_from_energy")
bmks = []
for target in sort(unique(labels))
for factual in sort(unique(labels))
if factual == target
continue
end
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="Moons",
n_individuals=5,
target=target, factual=factual,
initialization=:identity,
converge_when=:generator_conditions,
)
push!(bmks, bmk)
end
end
bmk = reduce(vcat, bmks)
CSV.write(joinpath(output_path, "moons_benchmark.csv"), bmk())
```
```{julia}
df = @chain bmk() begin
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
......
```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```
# Results
```{julia}
df = DataFrame()
for _file in filter(endswith("_benchmark.csv"), readdir(output_path))
x = joinpath(output_path, _file)
_df = CSV.read(x, DataFrame)
df = vcat(df, _df, cols=:intersect)
end
synth = ("Moons", "Circles")
df.source .= ifelse.(df.dataname .∈ [synth], :synthetic, :real)
```
```{julia}
tab = @chain df begin
@group_by(dataname, generator, model, variable, source)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable ∈ ["distance_from_energy", "distance_from_targets"])
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
@mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
end
```
\ No newline at end of file
No preview for this file type
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment