Skip to content
Snippets Groups Projects
Commit 5271ced9 authored by Pat Alt's avatar Pat Alt
Browse files

progres on granular examination

parent a3319dfe
No related branches found
No related tags found
No related merge requests found
......@@ -18,13 +18,14 @@ datasets = Dict(
)
# Hyperparameters:
cvgs = [0.5, 0.75, 0.90, 0.95, 0.99]
temps = [0.05, 0.1, 0.5, 1.0, 2.0]
Λ = [0.1, 1.0, 5.0]
cvgs = [0.5, 0.75, 0.95]
temps = [0.01, 0.1, 1.0]
Λ = [0.0, 0.1, 1.0]
l2_λ = 0.1
# Classifiers:
epochs = 100
link_fun = sigmoid
epochs = 250
link_fun = relu
logreg = NeuralNetworkClassifier(builder=MLJFlux.Linear(σ=link_fun), epochs=epochs)
mlp = NeuralNetworkClassifier(builder=MLJFlux.MLP(hidden=(32,), σ=link_fun), epochs=epochs)
ensmbl = EnsembleModel(model=mlp, n=5)
......@@ -37,6 +38,9 @@ classifiers = Dict(
# Search parameters:
target = 2
factual = 1
max_iter = 50
gradient_tol = 1e-2
opt = Descent(0.01)
```
```{julia}
......@@ -46,6 +50,7 @@ for (dataname, data) in datasets
# Data:
X = table(permutedims(data.X))
y = data.output_encoder.labels
x = select_factual(data,rand(1:size(data.X,2)))
for (clf_name, clf) in classifiers, cov in cvgs
......@@ -56,21 +61,21 @@ for (dataname, data) in datasets
M = CCE.ConformalModel(mach.model, mach.fitresult)
# Set up CCE:
yhat = predict_label(M, data)
factual_label = data.y_levels[factual]
target_label = data.y_levels[target]
x = select_factual(data,rand(findall(yhat .== factual_label)))
factual_label = predict_label(M, data, x)[1]
target_label = data.y_levels[data.y_levels .!= factual_label][1]
for λ in Λ, temp in temps
# CCE for given classifier, coverage, temperature and λ:
generator = CCEGenerator(temp=temp, λ=λ)
generator = CCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)
@assert predict_label(M, data, x) != target_label
ce = try
generate_counterfactual(
x, target_label, data, M, generator;
initialization=:identity,
converge_when=:generator_conditions,
gradient_tol=gradient_tol,
max_iter=max_iter,
)
catch
missing
......@@ -82,7 +87,7 @@ for (dataname, data) in datasets
coverage = cov,
temperature = temp,
λ = λ,
counterfactual = ce,
ce = ce,
factual = factual_label,
target = target_label,
)
......@@ -96,5 +101,30 @@ end
```
```{julia}
function plot_ce(results; dataset=:multi_class, classifier=:mlp, λ=0.1, img_height=300, zoom=-0.2)
df_plot = results[results.dataset .== dataset,:] |>
res -> res[res.classifier .== classifier,:] |>
res -> res[res.λ .== λ,:]
plts = map(eachrow(df_plot)) do row
Plots.plot(
row.ce,
title="cov: $(row.coverage), temp: $(row.temperature)",
cbar=false,
zoom=zoom,
legend=false,
)
end
nrow = length(cvgs)
ncol = length(temps)
_layout = (nrow, ncol)
Plots.plot(
plts...,
size=img_height.*reverse(_layout), layout=_layout,
plot_title="λ: $λ, dataset: $dataset, classifier: $classifier",
)
end
```
```{julia}
plot_ce(results; λ=Λ[1], dataset=:moons)
```
using CounterfactualExplanations.Objectives
"Constructor for `CCEGenerator`."
function CCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=0.0, temp::Real=0.05, kwargs...)
function CCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], κ::Real=1.0, temp::Real=0.05, kwargs...)
function _set_size_penalty(ce::AbstractCounterfactualExplanation)
return CCE.set_size_penalty(ce; κ=κ, temp=temp)
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment