Skip to content
Snippets Groups Projects
Commit e6a88ab7 authored by Pat Alt's avatar Pat Alt
Browse files

more work on synthetic

parent 5271ced9
No related branches found
No related tags found
No related merge requests found
......@@ -12,6 +12,7 @@ book:
- index.qmd
- notebooks/proposal.qmd
- notebooks/intro.qmd
- notebooks/synthetic.qmd
- notebooks/references.qmd
bibliography: bib.bib
......
This diff is collapsed.
......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "fcc7af85fdd632057c2b8e76e581f46c71876dbf"
project_hash = "028f8f6a9b7bdf90783ab0490b56543ce7f8c8d6"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......
[deps]
AlgebraOfGraphics = "cbdf2221-f076-402e-a563-3d30da359d67"
CCE = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc"
......
......@@ -12,13 +12,16 @@ setup_notebooks = quote
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using CounterfactualExplanations.Evaluation: benchmark
using CounterfactualExplanations.Models: load_mnist_mlp
using CounterfactualExplanations.Objectives
using CSV
using Distributions
using Flux
using Images
using JointEnergyModels
using LinearAlgebra
using Markdown
using MLDatasets
using MLDatasets: convert2image
using MLJBase
......@@ -34,7 +37,7 @@ setup_notebooks = quote
Plots.theme(:wong)
Random.seed!(2023)
www_path = "www"
output_path = "artifacts"
output_path = "artifacts/results"
img_height = 300
end;
\ No newline at end of file
......@@ -44,6 +44,8 @@ opt = Descent(0.01)
```
```{julia}
#| echo: false
results = DataFrame()
for (dataname, data) in datasets
......@@ -101,6 +103,8 @@ end
```
```{julia}
#| echo: false
function plot_ce(results; dataset=:multi_class, classifier=:mlp, λ=0.1, img_height=300, zoom=-0.2)
df_plot = results[results.dataset .== dataset,:] |>
res -> res[res.classifier .== classifier,:] |>
......@@ -126,5 +130,90 @@ end
```
```{julia}
plot_ce(results; λ=Λ[1], dataset=:moons)
#| output: true
#| echo: false
for dataset in keys(datasets)
Markdown.parse("""### $dataset""")
for classifier in keys(classifiers)
Markdown.parse("""#### $classifier""")
Markdown.parse("""::: {.panel-tabset}""")
for λ in Λ
Markdown.parse("""##### λ: $λ""")
display(plot_ce(results; dataset=dataset, classifier=classifier, λ=λ))
end
Markdown.parse(""":::""")
end
end
```
## Benchmark
```{julia}
generators = Dict(
:wachter => GenericGenerator(opt=opt, λ=l2_λ),
:revise => REVISEGenerator(opt=opt, λ=l2_λ),
:greedy => GreedyGenerator(opt=opt),
)
# Untrained Models:
models = Dict("cov$(cov)" => CCE.ConformalModel(conformal_model(mlp; method=:simple_inductive, coverage=cov)) for cov in cvgs)
```
```{julia}
bmks = []
measures = [
CounterfactualExplanations.distance,
CCE.distance_from_energy,
CCE.distance_from_targets,
CounterfactualExplanations.validity,
]
for (dataname, dataset) in datasets
for λ in Λ, temp in temps
_generators = deepcopy(generators)
_generators[:cce] = CCEGenerator(temp=temp, λ=[l2_λ,λ], opt=opt)
_generators[:energy] = CCE.EnergyDrivenGenerator(λ=[l2_λ,λ], opt=opt)
_generators[:target] = CCE.TargetDrivenGenerator(λ=[l2_λ,λ], opt=opt)
bmk = benchmark(
dataset;
models=deepcopy(models),
generators=_generators,
measure=measures,
suppress_training=false, dataname=dataname,
n_individuals=5,
initialization=:identity,
)
bmk.evaluation.λ .= λ
bmk.evaluation.temperature .= temp
push!(bmks, bmk)
end
end
bmk = reduce(vcat, bmks)
```
```{julia}
CSV.write(joinpath(output_path, "synthetic_benchmark.csv"), bmk())
```
```{julia}
#| output: true
#| echo: false
df = bmk()
for dataname ∈ sort(unique(df.dataname))
Markdown.parse("""### $dataset""")
df_ = df[df.dataname .== dataname, :]
for λ in Λ, temp in temps
Markdown.parse("""#### λ: $λ""")
df_plot = df_[df_.λ .== λ, :]
df_plot = df_plot[df_plot.temperature .== temp, :]
plt = AlgebraOfGraphics.data(df_plot) * visual(BoxPlot) *
mapping(:generator, :value, row=:variable, col=:model, color=:generator)
plt = draw(
plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=200, height=180),
facet=(; linkyaxes=:minimal)
)
display(plt)
end
end
```
......@@ -8,4 +8,18 @@ function CCEGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1
_penalties = [Objectives.distance_l2, _set_size_penalty]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return Generator(; penalty=_penalties, λ=λ, kwargs...)
end
"Constructor for `EnergyDrivenGenerator`."
function EnergyDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...)
_penalties = [Objectives.distance_l2, CCE.distance_from_energy]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return Generator(; penalty=_penalties, λ=λ, kwargs...)
end
"Constructor for `TargetDrivenGenerator`."
function TargetDrivenGenerator(; λ::Union{AbstractFloat,Vector{<:AbstractFloat}}=[0.1, 1.0], kwargs...)
_penalties = [Objectives.distance_l2, CCE.distance_from_targets]
λ = λ isa AbstractFloat ? [0.0, λ] : λ
return Generator(; penalty=_penalties, λ=λ, kwargs...)
end
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment