Skip to content
Snippets Groups Projects
Commit afd37112 authored by pat-alt's avatar pat-alt
Browse files

most charts updates

parent 238e1cf4
No related branches found
No related tags found
1 merge request!91Camera ready
......@@ -2385,7 +2385,7 @@ version = "1.11.1"
[[deps.TaijaPlotting]]
deps = ["CategoricalArrays", "ConformalPrediction", "CounterfactualExplanations", "DataAPI", "Distributions", "Flux", "LaplaceRedux", "LinearAlgebra", "MLJBase", "ManifoldLearning", "MultivariateStats", "NaturalSort", "NearestNeighborModels", "Plots"]
path = "../../TaijaPlotting.jl"
git-tree-sha1 = "d54b5bee7ae6e0bee3b3a13dfd662fa090e0b445"
uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240"
version = "1.0.5"
......
```{julia}
using Pkg; Pkg.activate("experiments")
include("$(pwd())/experiments/setup_env.jl")
using TaijaPlotting
Plots.theme(:wong)
```
## Comparison of Generators
```{julia}
outcome = Serialization.deserialize("results_extra/mnist_outcome.jls")
```
\ No newline at end of file
```{julia}
using Pkg; Pkg.activate("experiments")
include("$(pwd())/experiments/setup_env.jl")
include("$(pwd())/experiments/notebooks/setup.jl")
```
## Motivation
### Wachter and JSMA
```{julia}
Random.seed!(2023)
# Data:
counterfactual_data = load_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
input_dim, n_obs = size(counterfactual_data.X)
M = load_mnist_mlp()
# Target:
factual_label = 9
x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 7
factual = predict_label(M, counterfactual_data, x_factual)[1]
γ = 0.90
# Training params:
T = 100
```
```{julia}
# Search:
generic_generator = WachterGenerator(opt=Flux.Adam(0.25))
ce_wachter = generate_counterfactual(
x_factual, target, counterfactual_data, M, generic_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
greedy_generator = GreedyGenerator(η=2.0)
ce_jsma = generate_counterfactual(
x_factual, target, counterfactual_data, M, greedy_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
)
```
```{julia}
p1 = Plots.plot(
convert2image(MNIST, reshape(x_factual,28,28)),
axis=([], false),
size=(img_height, img_height),
title="Factual"
)
plts = [p1]
ces = zip([ce_wachter,ce_jsma])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wachter","JSMA"])
ce, _phat, _name = (x[1],x[2],x[3])
_title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
plt = Plots.plot(
convert2image(MNIST, reshape(ce,28,28)),
axis=([], false),
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "you_may_not_like_it.png"))
```
#### REVISE
```{julia}
using CounterfactualExplanations.Models: load_mnist_vae
vae = load_mnist_vae()
vae_weak = load_mnist_vae(;strong=false)
```
```{julia}
# Define generator:
revise_generator = REVISEGenerator(
opt = Flux.Optimise.Adam(0.25),
λ=0.1,
)
# Generate recourse:
counterfactual_data.generative_model = vae # assign generative model
ce_strong = generate_counterfactual(
x_factual, target, counterfactual_data, M, revise_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
counterfactual_data_weak = deepcopy(counterfactual_data)
counterfactual_data_weak.generative_model = vae_weak
ce_weak = generate_counterfactual(
x_factual, target, counterfactual_data_weak, M, revise_generator;
decision_threshold=γ, max_iter=T,
initialization=:identity,
converge_when=:generator_conditions,
)
```
```{julia}
ces = zip([ce_strong,ce_weak])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
plts = [p1]
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Strong VAE","Weak VAE"])
ce, _phat, _name = (x[1],x[2],x[3])
_title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
plt = Plots.plot(
convert2image(MNIST, reshape(ce,28,28)),
axis=([], false),
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png"))
```
```{julia}
ces = zip([ce_wachter, ce_jsma, ce_strong])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
plts = [p1]
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3), ["Wachter","Schut","REVISE"])
ce, _phat, _name = (x[1],x[2],x[3])
_title = "$(_name) (p=$(round(_phat[1]; digits=2)))"
plt = Plots.plot(
convert2image(MNIST, reshape(ce,28,28)),
axis=([], false),
size=(img_height, img_height),
title=_title
)
plts = [plts..., plt]
end
plt = Plots.plot(plts...; size=(0.8*panel_height*length(plts),0.8*panel_height), layout=(1,length(plts)), dpi=400)
display(plt)
savefig(plt, joinpath(paper_figpath, "mnist_motivation.png"))
```
## Comparison with ECCCo
```{julia}
outcome = Serialization.deserialize("results_extra/mnist_outcome.jls")
```
\ No newline at end of file
```{julia}
using Pkg; Pkg.activate("experiments")
include("$(pwd())/experiments/setup_env.jl")
using TaijaPlotting
Plots.theme(:wong)
include("$(pwd())/experiments/notebooks/setup.jl")
```
## Data
......@@ -112,7 +111,7 @@ p0 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_set
p1 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_set_loss=true, zoom=0, temp=temp)
p2 = Plots.contourf(mach.model, mach.fitresult, permutedims(X), labels; plot_classification_loss=true, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2))
plt = display(Plots.plot(p0, p1, p2, size=(1400,320), layout=(1,3)))
# savefig(joinpath(output_images_path, "poc_set_size.png"))
savefig(joinpath(output_images_path, "poc_set_size.png"))
```
## Counterfactuals
......@@ -167,20 +166,22 @@ for (name, generator) in generator_dict
end
plt = Plots.plot(plts..., size=(500,520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc.png"))
```
```{julia}
using Colors
panel_height = 200
col_pal = palette(:seaborn_colorblind)
Random.seed!(1234)
Random.seed!(1234) # 1234
using CounterfactualExplanations.Generators: ∇
λ₁ = 0.2
λ₂ = 1.5
λ₂ = 2.0
λ₃ = 3.0
Λ = [λ₁, λ₂, λ₃]
η = 0.01
max_iter = 100
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
factual_label = levels(labels)[2]
......@@ -191,7 +192,7 @@ factual = predict_label(M, counterfactual_data, x_factual)[1]
opt = Flux.Optimise.Descent(η)
generator_dict = OrderedDict(
"Wachter" => WachterGenerator(λ = 0.7, opt=opt),
"Wachter" => WachterGenerator(λ = 1.0, opt=opt),
"ECCCo (no EBM)" => ECCCoGenerator(λ = [λ₁,λ₂,0.0], opt=opt),
"ECCCo (no CP)" => ECCCoGenerator(λ = [λ₁,0.0,λ₃], opt=opt),
"ECCCo" => ECCCoGenerator(λ = Λ, opt=opt),
......@@ -237,6 +238,7 @@ for (name, generator) in generator_dict
x_factual, target, counterfactual_data, M, generator;
initialization=:identity,
converge_when=:generator_conditions,
max_iter=max_iter,
)
# Main plot (path):
......@@ -270,7 +272,7 @@ for (name, generator) in generator_dict
ces[name] = ce
end
plt = Plots.plot(plts...; size=(panel_height*length(plts),panel_height), layout=(1,length(plts)), dpi=300)
# plt = Plots.plot(plts..., size=(1000,250), layout=(1,4), dpi=300)
display(plt)
savefig(plt, joinpath(paper_figpath, "poc_gradient_fields.png"))
```
using TaijaPlotting
Plots.theme(:wong)
img_height = 300
panel_height = 250
paper_figpath = "paper/camera-ready/figures"
output_images_path = "www/"
\ No newline at end of file
No preview for this file type
......@@ -217,7 +217,7 @@ To assess counterfactuals with respect to Definition~\ref{def:faithful}, we need
where $\mathbf{r}_j \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ is the stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed~\citep{welling2011bayesian}. The term $\mathcal{E}_{\theta}(\mathbf{x}_j|\mathbf{y}^+)$ denotes the model energy conditioned on the target class label $\mathbf{y}^+$ which we specify as the negative logit corresponding to the target class label $\mathbf{y}^{+}$. To allow for faster sampling, we follow the common practice of choosing the step-size $\epsilon_j$ and the standard deviation of $\mathbf{r}_j$ separately. While $\mathbf{x}_J$ is only guaranteed to distribute as $p_{\theta}(\mathbf{x}|\mathbf{y}^{+})$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice \citep{murphy2023probabilistic}.
Generating multiple samples using SGLD thus yields an empirical distribution $\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}$ that approximates what the model has learned about the input data. While in the context of EBM, this is usually done during training, we propose to repurpose this approach during inference in order to evaluate the faithfulness of model explanations. The appendix provides additional implementation details for any tasks related to energy-based modeling.
Generating multiple samples using SGLD thus yields an empirical distribution $\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}$ that approximates what the model has learned about the input data. While in the context of EBM, this is usually done during training, we propose to repurpose this approach during inference in order to evaluate the faithfulness of model explanations. The appendix provides additional implementation details for any tasks related to energy-based modelling.
\subsection{Quantifying the Model's Predictive Uncertainty}
......@@ -311,7 +311,7 @@ Concerning feature autoencoding ($f: \mathcal{Z} \mapsto \mathcal{X}$), \textit{
Figure~\ref{fig:poc} illustrates how the different components in Equation~\ref{eq:eccco} affect the counterfactual search for a synthetic dataset. The underlying classifier is a Joint Energy Model (\textit{JEM}) that was trained to predict the output class (blue or orange) and generate class-conditional samples~\citep{grathwohl2020your}. We have used four different generator flavours to produce a counterfactual in the blue class for a sample from the orange class: \textit{Wachter}, which only uses the first penalty ($\lambda_2=\lambda_3=0$); \textit{ECCCo (no EBM)}, which does not constrain energy ($\lambda_2=0$); \textit{ECCCo (no CP)}, which involves no set size penalty ($\lambda_3=0$); and, finally, \textit{ECCCo}, which involves all penalties defined in Equation~\ref{eq:eccco}. Arrows indicate (negative) gradients with respect to the objective function at different points in the feature space.
While \textit{Wachter} generates a valid counterfactual, it ends up close to the original starting point consistent with its objective. \textit{ECCCo (no EBM)} pushes the counterfactual further into the target domain to minimize predictive uncertainty, but the outcome is still not plausible. The counterfactual produced by \textit{ECCCo (no CP)} is energy-constrained. Since the \textit{JEM} has learned the conditional input distribution reasonably well in this case, the counterfactuals are both faithful and plausible. Finally, the outcome for \textit{ECCCo} looks similar, but the additional smooth set size penalty leads to somewhat faster convergence.
While \textit{Wachter} generates a valid counterfactual, it ends up close to the original starting point consistent with its objective. \textit{ECCCo (no EBM)} avoids regions of high predictive uncertainty near the decision boundary, but the outcome is still not plausible. The counterfactual produced by \textit{ECCCo (no CP)} is energy-constrained. Since the \textit{JEM} has learned the conditional input distribution reasonably well in this case, the counterfactual is both faithful and plausible. Finally, the outcome for \textit{ECCCo} looks similar, but the additional smooth set size penalty leads to somewhat faster convergence.
\section{Empirical Analysis}\label{emp}
......@@ -467,7 +467,7 @@ This work leverages ideas from energy-based modelling and conformal prediction i
\section*{Acknowledgements}
Some of the members of TU Delft were partially funded by ICAI AI for Fintech Research, an ING---TU Delft
collaboration. Research reported in this work was partially or completely facilitated by computational resources and support of the DelftBlue~\citep{DHPC2022} and the Delft AI Cluster (DAIC: https://doc.daic.tudelft.nl/) at TU Delft. The authors would like to thank Azza Ahmed, in particular, for her tremendous help with running Julia jobs on the cluster. The work remains the sole responsibility of the authors.
collaboration. Research reported in this work was partially or completely facilitated by computational resources and support of the DelftBlue~\citep{DHPC2022} and the Delft AI Cluster (DAIC: https://doc.daic.tudelft.nl/) at TU Delft. Detailed information about the utilized computing resources can be found in the appendix. The authors would like to thank Azza Ahmed, in particular, for her tremendous help with running Julia jobs on the cluster. The work remains the sole responsibility of the authors.
\bibliography{aaai24,bib}
......
......@@ -3231,6 +3231,17 @@
year = {2022},
}
@Article{lecun1998gradient,
author = {LeCun, Yann and Bottou, L{\'e}on and Bengio, Yoshua and Haffner, Patrick},
title = {Gradient-based learning applied to document recognition},
number = {11},
pages = {2278--2324},
volume = {86},
journal = {Proceedings of the IEEE},
publisher = {Ieee},
year = {1998},
}
@Comment{jabref-meta: databaseType:biblatex;}
@Comment{jabref-meta: keypatterndefault:[auth:lower][year][veryshorttitle:lower];}
......
paper/camera-ready/figures/mnist_motivation.png

52.6 KiB | W: | H:

paper/camera-ready/figures/mnist_motivation.png

53.9 KiB | W: | H:

paper/camera-ready/figures/mnist_motivation.png
paper/camera-ready/figures/mnist_motivation.png
paper/camera-ready/figures/mnist_motivation.png
paper/camera-ready/figures/mnist_motivation.png
  • 2-up
  • Swipe
  • Onion skin
paper/camera-ready/figures/poc_gradient_fields.png

578 KiB | W: | H:

paper/camera-ready/figures/poc_gradient_fields.png

793 KiB | W: | H:

paper/camera-ready/figures/poc_gradient_fields.png
paper/camera-ready/figures/poc_gradient_fields.png
paper/camera-ready/figures/poc_gradient_fields.png
paper/camera-ready/figures/poc_gradient_fields.png
  • 2-up
  • Swipe
  • Onion skin
www/poc.png

178 KiB

www/poc_set_size.png

210 KiB | W: | H:

www/poc_set_size.png

218 KiB | W: | H:

www/poc_set_size.png
www/poc_set_size.png
www/poc_set_size.png
www/poc_set_size.png
  • 2-up
  • Swipe
  • Onion skin
www/surrogate_gone_wrong.png

14.8 KiB | W: | H:

www/surrogate_gone_wrong.png

12.4 KiB | W: | H:

www/surrogate_gone_wrong.png
www/surrogate_gone_wrong.png
www/surrogate_gone_wrong.png
www/surrogate_gone_wrong.png
  • 2-up
  • Swipe
  • Onion skin
www/you_may_not_like_it.png

12.5 KiB | W: | H:

www/you_may_not_like_it.png

11 KiB | W: | H:

www/you_may_not_like_it.png
www/you_may_not_like_it.png
www/you_may_not_like_it.png
www/you_may_not_like_it.png
  • 2-up
  • Swipe
  • Onion skin
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment