Skip to content
Snippets Groups Projects
Commit a12f3f05 authored by pat-alt's avatar pat-alt
Browse files
parents 4ff5f10a 5ab6c75e
Branches 38-laplace
No related tags found
No related merge requests found
# Other:
/artifacts/
/.quarto/
/Manifest.toml
/replicated/
# Tex
......
......@@ -3,5 +3,13 @@ git-tree-sha1 = "860ef374887cfbaa8eca835b574092678907e446"
lazy = true
[[artifacts-.download]]
sha256 = "e0f9e32ceb9e70e43fde8dbd4ad3af6a701ac0c25633df259a7b79083fed2ae0"
sha256 = "5ed618dd02df05991e02fe6d8eecf3ac604edafec95e07a85d0e3487baf46e4a"
url = "https://github.com/pat-alt/ECCCo.jl/releases/download/results-paper-submission-1.8.5/artifacts-.tar.gz"
["results-paper-submission-1.8.5"]
git-tree-sha1 = "3be5119c4ce466017db79f10fdb72b97e745bd7d"
lazy = true
[["results-paper-submission-1.8.5".download]]
sha256 = "15c6d44aba4e9860ba2b55b70ce8a2f41216d190a5e951fd6907b8cb616ec215"
url = "https://github.com/pat-alt/ECCCo.jl/releases/download/results-paper-submission-1.8.5/results-paper-submission-1.8.5.tar.gz"
@misc{ECCCo.jl,
author = {Patrick Altmeyer},
author = {Anonymous Author},
title = {ECCCo.jl},
url = {https://github.com/pat-alt/ECCCo.jl},
version = {v0.1.0},
......
This diff is collapsed.
name = "ECCCo"
uuid = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
authors = ["Patrick Altmeyer"]
authors = ["Anonymous Author"]
version = "0.1.0"
[deps]
......
# ECCCo
[![Build Status](https://github.com/pat-alt/ECCCo.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/pat-alt/ECCCo.jl/actions/workflows/CI.yml?query=branch%3Amain)
![](artifacts/results/images/poc_gradient_fields.png)
*Energy-Constrained Counterfactual Explanations.*
This work is currently undergoing peer review. This README is therefore only meant to provide reviewers access to the code base. The code base will be made public after the review process.
## Inspecting the Package Code
This code base is structured as a Julia package. The package code is located in the `src/` folder.
## Inspecting the Code for Experiments
We used [Quarto](https://quarto.org/) notebooks for prototyping and running experiments. The notebooks are located in the `notebooks/` folder, separated by dataset:
- [Linearly Separable](notebooks/linearly_separable.qmd)
- [Moons](notebooks/moons.qmd)
- [Circles](notebooks/circles.qmd)
- [MNIST](notebooks/mnist.qmd)
- [GMSC](notebooks/gmsc.qmd)
## Inspecting the Results
All results have been carefully reported either in the paper itself or in the supplementary material. In addition, we have released our results as binary files. These will be made publicly available after the review process.
## Reproducing the Results
To reproduce the results, you need to install the package, which will automatically install all dependencies. Since the package is not publicly registered and you are looking at an anonymous repository that [cannot be cloned](https://anonymous.4open.science/faq#download), unfortunately, it is not possible to easily install the package and reproduce the results at this stage of the review process.
However, provided that the package is indeed installed, you can reproduce the results by either running the experiments in the `experiments/` folder or using the notebooks listed above for a more interactive process.
**Note**: All experiments were run on `julia-1.8.5`. Since pre-trained models were serialised on that version they may not be compatible with newer versions of Julia.
### Command Line
The `experiments/` folder contains separate Julia scripts for each dataset and a [run_experiments.jl](experiments/run_experiments.jl) that calls the individual scripts. You can either cun these scripts inside a Julia session or just use the command line to execute them as described in the following.
To run the experiment for a single dataset, (e.g. `linearly_separable`) simply run the following command:
```shell
DATANAME=linearly_separable
julia experiments/run_experiments.jl
```
We use the following identifiers:
- `linearly_separable` (*Linearly Separable* data)
- `moons` (*Moons* data)
- `circles` (*Circles* data)
- `mnist` (*MNIST* data)
- `gmsc` (*GMSC* data)
To run all experiments at once you can instead just specify `DATANAME=all`.
Pre-trained versions of all of our black-box models have been archived as `Pkg` [artifacts](https://pkgdocs.julialang.org/v1/artifacts/) and are used by default. Should you wish to retrain the models as well, simply run the following command:
```shell
DATANAME=linearly_separable
RETRAIN=true
julia experiments/run_experiments.jl
```
When running the experiments from the command line, the parameter choices used in the main paper are applied by default. To have control over these choices, we recommend you instead rely on the notebooks.
### Notebooks
To run the notebooks and ensure that all package dependencies are installed, you need to clone this repo and open it on your device. The first cell in each notebook sets up the environment. You may have to [instantiate](https://pkgdocs.julialang.org/v1/api/#Pkg.instantiate) the local environment once. Should you prefer working with Jupyter notebooks instead of Quarto, you can easily [convert](https://quarto.org/docs/tools/vscode-notebook.html#converting-notebooks) them through a single command.
......@@ -6,7 +6,7 @@ project:
book:
title: "Conformal Counterfactual Explanations"
subtitle: "Online Companion"
author: "Patrick Altmeyer"
author: "Anonymous Author"
date: today
chapters:
- index.qmd
......
bib.bib 0 → 100644
This diff is collapsed.
......@@ -9,12 +9,16 @@ artifact_toml = LazyArtifacts.find_artifacts_toml(".")
function generate_artifacts(
datafiles;
artifact_name=nothing,
artifact_name="artifacts-$VERSION",
root=".",
artifact_toml=joinpath(root, "Artifacts.toml"),
deploy=true,
tag="artifacts-$(Int(VERSION.major)).$(Int(VERSION.minor))",
tag=nothing,
)
if isnothing(tag)
tag = replace(lowercase(artifact_name), " " => "-")
end
if deploy && !haskey(ENV, "GITHUB_TOKEN")
@warn "For automatic github deployment, need GITHUB_TOKEN. Not found in ENV, attemptimg global git config."
end
......@@ -72,15 +76,6 @@ function generate_artifacts(
end
end
function create_artifact_name_from_path(
datafiles::String, artifact_name::Union{Nothing,String}
)
# Name for hash/artifact:
artifact_name =
isnothing(artifact_name) ? replace(datafiles, ("/" => "-")) : artifact_name
return artifact_name
end
function get_git_remote_url(repo_path::String=".")
repo = LibGit2.GitRepo(repo_path)
origin = LibGit2.get(LibGit2.GitRemote, repo, "origin")
......
......@@ -6,7 +6,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
<meta name="author" content="Patrick Altmeyer">
<meta name="author" content="Anonymous Author">
<meta name="dcterms.date" content="2023-04-05">
<title>Conformal Counterfactual Explanations</title>
......@@ -148,7 +148,7 @@ ul.task-list li input[type="checkbox"] {
<div>
<div class="quarto-title-meta-heading">Author</div>
<div class="quarto-title-meta-contents">
<p>Patrick Altmeyer </p>
<p>Anonymous Author </p>
</div>
</div>
......
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_circles(n_obs; noise=0.05, factor=0.5); test_size=test_size)
run_experiment(
counterfactual_data, test_data; dataname="Circles",
n_hidden=32,
α=[1.0, 1.0, 1e-2],
sampling_batch_size=nothing,
sampling_steps=20,
λ₁=0.25,
λ₂ = 0.75,
λ₃ = 0.75,
opt=Flux.Optimise.Descent(0.01),
use_class_loss = false,
)
\ No newline at end of file
counterfactual_data, test_data = train_test_split(load_gmsc(nothing); test_size=test_size)
run_experiment(
counterfactual_data, test_data; dataname="GMSC",
n_hidden=128,
activation = Flux.swish,
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
Dense(n_hidden, n_hidden, activation),
Dense(n_hidden, n_out),
),
α=[1.0, 1.0, 1e-1],
sampling_batch_size=nothing,
sampling_steps = 30,
use_ensembling = true,
λ₁ = 0.1,
λ₂ = 0.5,
λ₃ = 0.5,
opt = Flux.Optimise.Descent(0.05),
use_class_loss=false,
use_variants=false,
)
\ No newline at end of file
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(
load_blobs(n_obs; cluster_std=0.1, center_box=(-1.0 => 1.0));
test_size=test_size
)
run_experiment(counterfactual_data, test_data; dataname="Linearly Separable")
\ No newline at end of file
function pre_process(x; noise::Float32=0.03f0)
ϵ = Float32.(randn(size(x)) * noise)
x += ϵ
return x
end
# Training data:
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
# VAE (trained on full dataset):
using CounterfactualExplanations.Models: load_mnist_vae
vae = load_mnist_vae()
counterfactual_data.generative_model = vae
# Test data:
test_data = load_mnist_test()
# Generators:
eccco_generator = ECCCoGenerator(
λ=[0.1,0.25,0.25],
temp=0.1,
opt=nothing,
use_class_loss=true,
nsamples=10,
nmin=10,
)
Λ = eccco_generator.λ
generator_dict = Dict(
"Wachter" => WachterGenerator(λ=Λ[1], opt=eccco_generator.opt),
"REVISE" => REVISEGenerator(λ=Λ[1], opt=eccco_generator.opt),
"Schut" => GreedyGenerator(η=2.0),
"ECCCo" => eccco_generator,
)
# Run:
run_experiment(
counterfactual_data, test_data; dataname="MNIST",
n_hidden = 128,
activation = Flux.swish,
builder = MLJFlux.@builder Flux.Chain(
Dense(n_in, n_hidden, activation),
Dense(n_hidden, n_out),
),
𝒟x = Uniform(-1.0, 1.0),
α = [1.0,1.0,1e-2],
sampling_batch_size = 10,
ssampling_steps=25,
use_ensembling = true,
generators = generator_dict,
)
\ No newline at end of file
n_obs = Int(2500 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(load_moons(n_obs); test_size=test_size)
run_experiment(
counterfactual_data, test_data; dataname="Moons",
epochs=500,
n_hidden=32,
activation = Flux.relu,
α=[1.0, 1.0, 1e-1],
sampling_batch_size=10,
sampling_steps=30,
λ₁=0.25,
λ₂=0.75,
λ₃=0.75,
opt=Flux.Optimise.Descent(0.05),
use_class_loss=false
)
\ No newline at end of file
include("setup.jl")
# User inputs:
if ENV["DATANAME"] == "all"
datanames = ["linearly_separable", "moons", "circles", "mnist", "gmsc"]
else
datanames = [ENV["DATANAME"]]
end
# Linearly Separable
if "linearly_separable" in datanames
@info "Running linearly separable experiment."
include("linearly_separable.jl")
end
# Moons
if "moons" in datanames
@info "Running moons experiment."
include("moons.jl")
end
# Circles
if "circles" in datanames
@info "Running circles experiment."
include("circles.jl")
end
# MNIST
if "mnist" in datanames
@info "Running MNIST experiment."
include("mnist.jl")
end
# GMSC
if "gmsc" in datanames
@info "Running GMSC experiment."
include("gmsc.jl")
end
# General setup:
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
output_path = "$(pwd())/replicated"
isdir(output_path) || mkdir(output_path)
@info "All results will be saved to $output_path."
params_path = "$(pwd())/replicated/params"
isdir(params_path) || mkdir(params_path)
@info "All parameter choices will be saved to $params_path."
test_size = 0.2
# Artifacts:
using LazyArtifacts
@warn "Models were pre-trained on `julia-1.8.5` and may not work on other versions."
artifact_path = joinpath(artifact"results-paper-submission-1.8.5","results-paper-submission-1.8.5")
pretrained_path = joinpath(artifact_path, "results")
function run_experiment(
counterfactual_data,
test_data;
dataname,
output_path=output_path,
params_path=params_path,
pretrained_path=pretrained_path,
retrain=false,
epochs=100,
n_hidden=16,
activation=Flux.swish,
builder=MLJFlux.MLP(
hidden=(n_hidden, n_hidden, n_hidden),
σ=activation
),
n_ens=5,
𝒟x=Normal(),
sampling_batch_size=50,
α=[1.0, 1.0, 1e-1],
verbosity=10,
sampling_steps=30,
use_ensembling=false,
coverage=.95,
λ₁=0.25,
λ₂ = 0.75,
λ₃ = 0.75,
opt=Flux.Optimise.Descent(0.01),
use_class_loss=false,
use_variants=true,
n_individuals=25,
generators=nothing,
)
# SETUP ----------
# Data
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
save_name = replace(lowercase(dataname), " " => "_")
# Model parameters:
batch_size = minimum([Int(round(n_obs / 10)), 128])
sampling_batch_size = isnothing(sampling_batch_size) ? batch_size : sampling_batch_size
_loss = Flux.Losses.crossentropy # loss function
_finaliser = Flux.softmax # finaliser function
# JEM parameters:
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
𝒟x, 𝒟y,
input_size=(input_dim,),
batch_size=sampling_batch_size,
)
# MODELS ----------
# Simple MLP:
mlp = NeuralNetworkClassifier(
builder=builder,
epochs=epochs,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
)
# Deep Ensemble:
mlp_ens = EnsembleModel(model=mlp, n=n_ens)
# Joint Energy Model:
jem = JointEnergyClassifier(
sampler;
builder=builder,
epochs=epochs,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
jem_training_params=(
α=α, verbosity=verbosity,
),
sampling_steps=sampling_steps
)
# Deep Ensemble of Joint Energy Models:
jem_ens = EnsembleModel(model=jem, n=n_ens)
# Dictionary of models:
if !use_ensembling
models = Dict(
"MLP" => mlp,
"JEM" => jem,
)
else
models = Dict(
"MLP" => mlp,
"MLP Ensemble" => mlp_ens,
"JEM" => jem,
"JEM Ensemble" => jem_ens,
)
end
# TRAINING ----------
function _train(model, X=X, y=labels; cov=coverage, method=:simple_inductive, mod_name="model")
conf_model = conformal_model(model; method=method, coverage=cov)
mach = machine(conf_model, X, y)
@info "Begin training $mod_name."
fit!(mach)
@info "Finished training $mod_name."
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
return M
end
if retrain
@info "Retraining models."
model_dict = Dict(mod_name => _train(model; mod_name=mod_name) for (mod_name, model) in models)
Serialization.serialize(joinpath(output_path, "$(save_name)_models.jls"), model_dict)
else
@info "Loading pre-trained models."
model_dict = Serialization.deserialize(joinpath(pretrained_path, "$(save_name)_models.jls"))
end
params = DataFrame(
Dict(
:n_obs => Int.(round(n_obs/10)*10),
:epochs => epochs,
:batch_size => batch_size,
:n_hidden => n_hidden,
:n_layers => length(model_dict["MLP"].fitresult[1][1])-1,
:activation => string(activation),
:n_ens => n_ens,
:lambda => string(α[3]),
:jem_sampling_steps => jem.sampling_steps,
:sgld_batch_size => sampler.batch_size,
:dataname => dataname,
)
)
CSV.write(joinpath(params_path, "$(save_name)_model_params.csv"), params)
measure = Dict(
:f1score => multiclass_f1score,
:acc => accuracy,
:precision => multiclass_precision
)
model_performance = DataFrame()
for (mod_name, model) in model_dict
# Test performance:
_perf = CounterfactualExplanations.Models.model_evaluation(model, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
_perf.dataname .= dataname
model_performance = vcat(model_performance, _perf)
end
Serialization.serialize(joinpath(output_path, "$(save_name)_model_performance.jls"), model_performance)
CSV.write(joinpath(output_path, "$(save_name)_model_performance.csv"), model_performance)
@info "Model performance:"
println(model_performance)
# COUNTERFACTUALS ----------
@info "Begin benchmarking counterfactual explanations."
Λ = [λ₁, λ₂, λ₃]
generator_params = DataFrame(
Dict(
:λ1 => λ₁,
:λ2 => λ₂,
:λ3 => λ₃,
:opt => string(typeof(opt)),
:eta => opt.eta,
:dataname => dataname,
)
)
CSV.write(joinpath(params_path, "$(save_name)_generator_params.csv"), generator_params)
# Benchmark generators:
if !isnothing(generators)
generator_dict = generators
elseif use_variants
generator_dict = Dict(
"Wachter" => WachterGenerator(λ=λ₁, opt=opt),
"REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
"Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
"ECCCo (no CP)" => ECCCoGenerator(λ=[λ₁, 0.0, λ₃], opt=opt, use_class_loss=use_class_loss),
"ECCCo (no EBM)" => ECCCoGenerator(λ=[λ₁, λ₂, 0.0], opt=opt, use_class_loss=use_class_loss),
)
else
generator_dict = Dict(
"Wachter" => WachterGenerator(λ=λ₁, opt=opt),
"REVISE" => REVISEGenerator(λ=λ₁, opt=opt),
"Schut" => GreedyGenerator(),
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
)
end
# Measures:
measures = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.Evaluation.validity,
CounterfactualExplanations.Evaluation.redundancy,
ECCCo.set_size_penalty
]
bmks = []
for target in sort(unique(labels))
for factual in sort(unique(labels))
if factual == target
continue
end
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname=dataname,
n_individuals=n_individuals,
target=target, factual=factual,
initialization=:identity,
converge_when=:generator_conditions
)
push!(bmks, bmk)
end
end
bmk = reduce(vcat, bmks)
CSV.write(joinpath(output_path, "$(save_name)_benchmark.csv"), bmk())
end
\ No newline at end of file
......@@ -6,7 +6,7 @@
<meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=yes">
<meta name="author" content="Patrick Altmeyer">
<meta name="author" content="Anonymous Author">
<meta name="dcterms.date" content="2023-04-06">
<title>Conformal Counterfactual Explanations</title>
......@@ -148,7 +148,7 @@ ul.task-list li input[type="checkbox"] {
<div>
<div class="quarto-title-meta-heading">Author</div>
<div class="quarto-title-meta-contents">
<p>Patrick Altmeyer </p>
<p>Anonymous Author </p>
</div>
</div>
......
This diff is collapsed.
......@@ -16,6 +16,7 @@ ECCCo = "0232c203-4013-4b0d-ad96-43e3e11ac3bf"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
JointEnergyModels = "48c56d24-211d-4463-bbc0-7a701b291131"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJEnsembles = "50ed68f4-41fd-4504-931a-ed422449fee0"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment