Skip to content
Snippets Groups Projects
Commit e4f42b65 authored by Pat Alt's avatar Pat Alt
Browse files
parents 7b9f64ae ba3abd14
No related branches found
No related tags found
No related merge requests found
artifacts/results/images/mnist_eccco.png

18.2 KiB | W: | H:

artifacts/results/images/mnist_eccco.png

18.2 KiB | W: | H:

artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
artifacts/results/images/mnist_eccco.png
  • 2-up
  • Swipe
  • Onion skin
No preview for this file type
No preview for this file type
......@@ -137,7 +137,7 @@ savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png"))
```{julia}
function pre_process(x; noise::Float32=0.03f0)
ϵ = Float32.(randn(size(x)) * noise)
# x = @.(2 * x - 1)
x = @.(2 * x - 1)
x += ϵ
return x
end
......@@ -145,8 +145,8 @@ end
```{julia}
# Hyper:
_retrain = true
_regen = true
_retrain = false
_regen = false
# Data:
n_obs = 10000
......@@ -307,7 +307,7 @@ model_performance = DataFrame()
for (mod_name, mod) in model_dict
# Test performance:
test_data = load_mnist_test()
# test_data.X = pre_process.(test_data.X, noise=0.0f0)
test_data.X = pre_process.(test_data.X, noise=0.0f0)
_perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
_perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
_perf.mod_name .= mod_name
......@@ -321,6 +321,8 @@ model_performance
### Different Models
```{julia}
plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"]
# ECCCo:
λ=[0.5,0.1,0.5]
temp=0.5
......@@ -343,7 +345,7 @@ for (mod_name, mod) in model_dict
)
ces[mod_name] = ce
end
plt_order = sortperm(collect(keys(ces)))
_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)
# Plot:
p1 = Plots.plot(
......@@ -366,7 +368,7 @@ for (_name,ce) in ces
)
plts = [plts..., plt]
end
plts = plts[plt_order]
plts = plts[_plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
......
No preview for this file type
......@@ -238,8 +238,6 @@ where $\hat{\mathbf{x}}_{\theta}$ denotes samples generated using SGLD (Equation
The first two terms in Equation~\ref{eq:eccco} correspond to the counterfactual search objective defined in~\citet{wachter2017counterfactual} which merely penalises the distance of counterfactuals from their factual values. The additional two penalties in ECCCo ensure that counterfactuals conform with the model's generative property and lead to minimally uncertain predictions, respectively. The hyperparameters $\lambda_1, ..., \lambda_3$ can be used to balance the different objectives: for example, we may choose to incur larger deviations from the factual in favour of conformity with the model's generative property by choosing lower values of $\lambda_1$ and relatively higher values of $\lambda_2$. Figure~\ref{fig:eccco} illustrates this balancing act for an example involving synthetic data: vector fields indicate the direction of gradients with respect to the different components our proposed objective function (Equation~\ref{eq:eccco}).
The entire procedure for Generating ECCCos is described in Algorithm~\ref{alg:eccco}. For the sake of simplicity and without loss of generality, we limit our attention to generating a single counterfactual $\mathbf{x}^\prime=f(\mathbf{z}^\prime)$ where in contrast to Equation~\ref{eq:eccco} $\mathbf{z}^\prime$ denotes a $1$-dimensional array containing a single counterfactual state. That state is initialized by passing the factual $\mathbf{x}$ through the encoder $f^{-1}$ which in our case corresponds to a simple feature transformer, rather than the encoder part of VAE as in REVISE~\citep{joshi2019realistic}.
\medskip
\renewcommand{\algorithmicrequire}{\textbf{Input:}}
......@@ -254,14 +252,14 @@ The entire procedure for Generating ECCCos is described in Algorithm~\ref{alg:ec
\begin{minipage}[c]{0.50\textwidth}
\captionof{algorithm}{Generating ECCCos (For more details, see Appendix~\ref{app:eccco})}\label{alg:eccco}
\begin{algorithmic}[1]
\Require $\mathbf{x}, \mathbf{y}^*, M_{\theta}, f, \Lambda, \alpha, \mathcal{D}, T, \eta, m, M$ \linebreak where $M_{\theta}(\mathbf{x})\neq\mathbf{y}^*$
\Require $\mathbf{x}, \mathbf{y}^*, M_{\theta}, f, \Lambda, \alpha, \mathcal{D}, T, \eta, n_{\mathcal{B}}, N_{\mathcal{B}}$ \linebreak where $M_{\theta}(\mathbf{x})\neq\mathbf{y}^*$
\Ensure $\mathbf{x}^\prime$
\State Initialize $\mathbf{z}^\prime \gets f^{-1}(\mathbf{x})$
\State Generate buffer $\mathcal{B}$ of $M$ conditional samples $\hat{\mathbf{x}}_{\theta}|\mathbf{y}^*$ using SGLD (Equation~\ref{eq:sgld})
\State Generate buffer $\mathcal{B}$ of $N_{\mathcal{B}}$ conditional samples $\hat{\mathbf{x}}_{\theta}|\mathbf{y}^*$ using SGLD (Equation~\ref{eq:sgld})
\State Run \textit{SCP} for $M_{\theta}$ using $\mathcal{D}$
\State Initialize $t \gets 0$
\While{\textit{not converged} or $t < T$}
\State $\hat{\mathbf{x}}_{\theta, t} \gets \text{rand}(\mathcal{B},m)$
\State $\hat{\mathbf{x}}_{\theta, t} \gets \text{rand}(\mathcal{B},n_{\mathcal{B}})$
\State $\mathbf{z}^\prime \gets \mathbf{z}^\prime - \eta \nabla_{\mathbf{z}^\prime} \mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^*,\hat{\mathbf{x}}_{\theta, t})$
\State $t \gets t+1$
\EndWhile
......@@ -271,7 +269,20 @@ The entire procedure for Generating ECCCos is described in Algorithm~\ref{alg:ec
\medskip
\section{Evaluation Framework}\label{conformity}
\begin{minipage}[c]{\textwidth}
\includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png}
\captionof{figure}{Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:eccco-mnist}
\end{minipage}
\medskip
The entire procedure for Generating ECCCos is described in Algorithm~\ref{alg:eccco}. For the sake of simplicity and without loss of generality, we limit our attention to generating a single counterfactual $\mathbf{x}^\prime=f(\mathbf{z}^\prime)$ where in contrast to Equation~\ref{eq:eccco} $\mathbf{z}^\prime$ denotes a $1$-dimensional array containing a single counterfactual state. That state is initialized by passing the factual $\mathbf{x}$ through the encoder $f^{-1}$ which in our case corresponds to a simple feature transformer, rather than the encoder part of VAE as in REVISE~\citep{joshi2019realistic}. Next, we generate a buffer of $N_{\mathcal{B}}$ conditional samples $\hat{\mathbf{x}}_{\theta}|\mathbf{y}^*$ using SGLD (Equation~\ref{eq:sgld}) and conformalise the model $M_{\theta}$ through Split Conformal Prediction on training data $\mathcal{D}$.
Finally, we search counterfactuals through gradient descent. Let $\mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^*,\hat{\mathbf{x}}_{\theta, t})$ denote our loss function defined in Equation~\ref{eq:eccco}. Then in each iteration, we first randomly draw $n_{\mathcal{B}}$ samples from the buffer $\mathcal{B}$ before updating the counterfactual state $\mathbf{z}^\prime$ by moving in the negative direction of that loss function. The search terminates once the convergence criterium is met or the maximum number of iterations $T$ has been exhausted. Note that the choice of convergence criterium has important implications on the final counterfactual. For more detail on this see Appendix~\ref{app:eccco}).
Figure~\ref{fig:eccco-mnist} presents ECCCos for the MNIST example from Section~\ref{background} for various Black Box models of increasing complexity from left to right: a simple Multi-Layer Perceptron (MLP); an Ensemble of MLPs, each of the same architecture as the single MLP; a Joint Energy Model (JEM) based on the same MLP architecture; and finally, an Ensemble of these JEMs. Since Deep Ensembles have an improved capacity for predictive uncertainty quantification and JEMs are explicitly trained to learn plausible representations of the input data, it is intuitive to see that the plausibility of counterfactuals visibly improves from left to right.
\section{Experiments}\label{conformity}
\subsection{Evaluation Measures}\label{evaluation}
......@@ -297,10 +308,7 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar
\section{Experiments}
\begin{figure}
\includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png}
\caption{ECCCos from Black Boxes. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); }\label{fig:eccco}
\end{figure}
\begin{itemize}
\item BatchNorm does not seem compatible with JEM
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment