Skip to content
Snippets Groups Projects
Commit 04174b2a authored by Pat Alt's avatar Pat Alt
Browse files

more work on granular stuff

parent df4c7725
No related branches found
No related tags found
No related merge requests found
This diff is collapsed.
......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "028f8f6a9b7bdf90783ab0490b56543ce7f8c8d6"
project_hash = "512d9080e47cd18c9e3a640716f9947dd8512bcb"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......
......@@ -7,6 +7,7 @@ CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc"
ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
......
......@@ -4,7 +4,7 @@ setup_notebooks = quote
Pkg.activate("notebooks")
using AlgebraOfGraphics
using AlgebraOfGraphics: Violin, BoxPlot
using AlgebraOfGraphics: Violin, BoxPlot, BarPlot
using CairoMakie
using CCE
using CCE: set_size_penalty, distance_from_energy, distance_from_targets
......@@ -12,10 +12,11 @@ setup_notebooks = quote
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using CounterfactualExplanations.Evaluation: benchmark
using CounterfactualExplanations.Models: load_mnist_mlp
using CounterfactualExplanations.Evaluation: benchmark, evaluate
using CounterfactualExplanations.Models: load_mnist_mlp, train
using CounterfactualExplanations.Objectives
using CSV
using DataFrames
using Distributions
using Flux
using Images
......
......@@ -20,7 +20,7 @@ datasets = Dict(
# Hyperparameters:
cvgs = [0.5, 0.75, 0.95]
temps = [0.01, 0.1, 1.0]
Λ = [0.0, 0.1, 1.0]
Λ = [0.0, 0.1, 1.0, 10.0]
l2_λ = 0.1
# Classifiers:
......@@ -150,24 +150,153 @@ end
## Benchmark
```{julia}
# Benchmark generators:
generators = Dict(
:wachter => GenericGenerator(opt=opt, λ=l2_λ),
:revise => REVISEGenerator(opt=opt, λ=l2_λ),
:greedy => GreedyGenerator(opt=opt),
:greedy => GreedyGenerator(),
)
# Untrained Models:
models = Dict("cov$(cov)" => CCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)
```
models = Dict(Symbol("cov$(Int(100*cov))") => CCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)
```{julia}
bmks = []
# Measures:
measures = [
CounterfactualExplanations.distance,
CCE.distance_from_energy,
CCE.distance_from_targets,
CounterfactualExplanations.validity,
]
```
### Single CE
```{julia}
#| echo: false
_temp = 0.01
results = DataFrame()
for (dataname, data) in datasets
# Data:
X = table(permutedims(data.X))
y = data.output_encoder.labels
x = select_factual(data,rand(1:size(data.X,2)))
for (modelname, M) in deepcopy(models)
# Model training:
M = train(M, data)
# Set up CCE:
factual_label = predict_label(M, data, x)[1]
target_label = data.y_levels[data.y_levels .!= factual_label][1]
for λ in Λ
# Generators:
_generators = deepcopy(generators)
_generators[:cce] = CCEGenerator(temp=_temp, λ=[l2_λ,λ], opt=opt)
_generators[:energy] = CCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)
_generators[:target] = CCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)
for (gen_name, gen) in _generators
# CCE for given models, λ and generator:
@assert predict_label(M, data, x) != target_label
ce = try
generate_counterfactual(
x, target_label, data, M, gen;
initialization=:identity,
converge_when=:generator_conditions,
gradient_tol=gradient_tol,
max_iter=max_iter,
)
catch
missing
end
if !ismissing(ce)
eval = DataFrame(evaluate(ce, measure=measures, output_format=:Dict))
else
eval = DataFrame(Dict(Symbol(fun) => missing for fun in measures))
end
_results = DataFrame(
dataset = dataname,
model = modelname,
λ = λ,
generator = gen_name,
ce = ce,
factual = factual_label,
target = target_label,
)
_results = crossjoin(_results, eval; makeunique=true)
append!(results, _results)
end
end
end
end
```
```{julia}
#| echo: false
function plot_benchmark(results; dataset=:multi_class, modelname=:cov95, img_height=300, zoom=-0.2)
df_plot = results[results.dataset .== dataset,:] |>
res -> res[res.model .== modelname,:]
plts = map(eachrow(df_plot)) do row
Plots.plot(
row.ce,
title="λ: $(row.λ), gen: $(row.generator)",
cbar=false,
zoom=zoom,
legend=false,
)
end
ncol = length(unique(df_plot.generator))
nrow = length(unique(df_plot.λ))
_layout = (nrow, ncol)
Plots.plot(
plts...,
size=img_height.*reverse(_layout), layout=_layout,
plot_title="dataset: $dataset, model: $modelname",
)
end
```
```{julia}
#| output: true
#| echo: false
df = @pivot_longer(results, distance:distance_from_targets)
for dataname ∈ sort(unique(df.dataset))
Markdown.parse("""### $dataname""")
df_ = df[df.dataset .== dataname, :]
for model in unique(df_.model)
Markdown.parse("""#### model: $model""")
df_plot = df_[df_.model .== model, :]
df_plot = @mutate(df_plot, lambda = string("λ: ", round(λ, digits=2)))
plt = AlgebraOfGraphics.data(df_plot) * visual(BarPlot) *
mapping(:generator, :value, row=:variable, col=:lambda, color=:generator)
plt = draw(
plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=200, height=180),
facet=(; linkyaxes=:minimal)
)
# plt.figure[0, :] = Label(
# plt.figure, "data: $dataname, model: $model",
# fontsize=20, tellwidth=false
# )
display(plt)
end
end
```
### Full Benchmark
```{julia}
bmks = []
for (dataname, dataset) in datasets
for λ in Λ, temp in temps
_generators = deepcopy(generators)
......@@ -201,7 +330,7 @@ CSV.write(joinpath(output_path, "synthetic_benchmark.csv"), bmk())
df = bmk()
for dataname ∈ sort(unique(df.dataname))
Markdown.parse("""### $dataset""")
Markdown.parse("""### $dataname""")
df_ = df[df.dataname .== dataname, :]
for λ in Λ, temp in temps
Markdown.parse("""#### λ: $λ""")
......
......@@ -36,7 +36,7 @@ end
function distance_from_energy(
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, from_buffer=true, agg=mean, kwargs...
n::Int=10000, from_buffer=true, agg=mean, kwargs...
)
conditional_samples = []
ignore_derivatives() do
......@@ -63,9 +63,10 @@ end
function distance_from_targets(
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, agg=mean
n::Int=10000, agg=mean
)
target_samples = counterfactual_explanation.data.X |>
target_idx = counterfactual_explanation.data.output_encoder.labels .== counterfactual_explanation.target
target_samples = counterfactual_explanation.data.X[:,target_idx] |>
X -> X[:,rand(1:end,n)]
x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
loss = map(eachslice(x′, dims=3)) do x
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment