diff --git a/artifacts/results/images/mnist_eccco.png b/artifacts/results/images/mnist_eccco.png
index c1a839162f51abfad2296abd0cafaa4d972503df..217094c975213d0afc36922c9b9d0d1653d1b7d4 100644
Binary files a/artifacts/results/images/mnist_eccco.png and b/artifacts/results/images/mnist_eccco.png differ
diff --git a/artifacts/results/mnist_vae.jls b/artifacts/results/mnist_vae.jls
index 67e2b2de0951f2d914bb4a75088c8ced1e118dbc..e571c3f2ac2db78575bfa61439e4da7b817bb3f7 100644
Binary files a/artifacts/results/mnist_vae.jls and b/artifacts/results/mnist_vae.jls differ
diff --git a/artifacts/results/mnist_vae_weak.jls b/artifacts/results/mnist_vae_weak.jls
index ecccbf390b1e0764896dc835e21633d9bb8b9c25..5717d5eeb1a7e905354727a7553c701f328a93be 100644
Binary files a/artifacts/results/mnist_vae_weak.jls and b/artifacts/results/mnist_vae_weak.jls differ
diff --git a/notebooks/mnist.qmd b/notebooks/mnist.qmd
index 4fdd65f540528b208e86fe72ca5607a8df733f33..e8b52e2313bb4eed9641ade3ceef39b8925664fb 100644
--- a/notebooks/mnist.qmd
+++ b/notebooks/mnist.qmd
@@ -137,7 +137,7 @@ savefig(plt, joinpath(output_images_path, "surrogate_gone_wrong.png"))
 ```{julia}
 function pre_process(x; noise::Float32=0.03f0)
     ϵ = Float32.(randn(size(x)) * noise)
-    # x = @.(2 * x - 1)
+    x = @.(2 * x - 1)
     x += ϵ
     return x
 end
@@ -145,8 +145,8 @@ end
 ```{julia}
 # Hyper:
-_retrain = true
-_regen = true
+_retrain = false
+_regen = false
 
 # Data:
 n_obs = 10000
@@ -307,7 +307,7 @@ model_performance = DataFrame()
 for (mod_name, mod) in model_dict
     # Test performance:
     test_data = load_mnist_test()
-    # test_data.X = pre_process.(test_data.X, noise=0.0f0)
+    test_data.X = pre_process.(test_data.X, noise=0.0f0)
     _perf = CounterfactualExplanations.Models.model_evaluation(mod, test_data, measure=collect(values(measure)))
     _perf = DataFrame([[p] for p in _perf], collect(keys(measure)))
     _perf.mod_name .= mod_name
@@ -321,6 +321,8 @@ model_performance
 ### Different Models
 
 ```{julia}
+plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"]
+
 # ECCCo:
 λ=[0.5,0.1,0.5]
 temp=0.5
@@ -343,7 +345,7 @@ for (mod_name, mod) in model_dict
     )
     ces[mod_name] = ce
 end
-plt_order = sortperm(collect(keys(ces)))
+_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)
 
 # Plot:
 p1 = Plots.plot(
@@ -366,7 +368,7 @@ for (_name,ce) in ces
     )
     plts = [plts..., plt]
 end
-plts = plts[plt_order]
+plts = plts[_plt_order]
 plts = [p1, plts...]
 plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
 display(plt)
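Note on the two `pre_process` hunks above: re-enabling `x = @.(2 * x - 1)` rescales MNIST pixels from [0, 1] to [-1, 1] at training time, and the second hunk applies the same transformation (with `noise=0.0f0`) to the test set, so that train and test inputs share one input domain. A minimal standalone sketch of the re-enabled behaviour (the toy input below is an assumption, not the notebook's data):

```julia
# Sketch of the re-enabled preprocessing: rescale [0, 1] pixels to [-1, 1],
# then add Gaussian noise. The input here is a toy array, not MNIST.
function pre_process(x; noise::Float32=0.03f0)
    ϵ = Float32.(randn(size(x)) * noise)
    x = @.(2 * x - 1)  # rescale [0, 1] -> [-1, 1]
    x += ϵ
    return x
end

x = rand(Float32, 28^2)                    # toy "image" in [0, 1]
x_train = pre_process(x)                   # training-time call (noisy)
x_test  = pre_process(x; noise=0.0f0)      # test-time call (deterministic)
@assert all(-1 .<= x_test .<= 1)
```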
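The plotting hunks replace the alphabetical `sortperm(collect(keys(ces)))` with an explicit name-based ordering so that panels always appear as MLP, MLP Ensemble, JEM, JEM Ensemble regardless of dictionary key order. A self-contained sketch of that index lookup (the `model_dict` values below are placeholders, not the notebook's fitted models):

```julia
# Sketch of the name-based panel ordering: map each desired model name
# to its position among the dictionary's keys.
model_dict = Dict(
    "JEM" => :jem,
    "MLP" => :mlp,
    "JEM Ensemble" => :jem_ensemble,
    "MLP Ensemble" => :mlp_ensemble,
)
plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"]

# For each name in `plt_order`, find its index in `keys(model_dict)`:
_plt_order = map(x -> findall(collect(keys(model_dict)) .== x)[1], plt_order)

# Any array built by iterating over `model_dict` can now be permuted
# into display order:
@assert collect(keys(model_dict))[_plt_order] == plt_order
```

Note that indexing the plots (built by iterating over `ces`) with key positions taken from `model_dict` implicitly assumes both dictionaries iterate their shared keys in the same order.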
diff --git a/paper/paper.pdf b/paper/paper.pdf
index 7f8f4cc72ad9d371e1c9c82e43d7b19bde004d3f..2859dae90c8dc3dd0fa0be53ee9a0557c3ac9354 100644
Binary files a/paper/paper.pdf and b/paper/paper.pdf differ
diff --git a/paper/paper.tex b/paper/paper.tex
index 648a2cc3f08e4f355184397622aeb2beee5c90d3..d362161cbbc956edf863b91e3953d2bc52c3e439 100644
--- a/paper/paper.tex
+++ b/paper/paper.tex
@@ -238,8 +238,6 @@ where $\hat{\mathbf{x}}_{\theta}$ denotes samples generated using SGLD (Equation
 The first two terms in Equation~\ref{eq:eccco} correspond to the counterfactual search objective defined in~\citet{wachter2017counterfactual}, which merely penalises the distance of counterfactuals from their factual values. The additional two penalties in ECCCo ensure that counterfactuals conform with the model's generative property and lead to minimally uncertain predictions, respectively.
 The hyperparameters $\lambda_1, ..., \lambda_3$ can be used to balance the different objectives: for example, we may choose to incur larger deviations from the factual in favour of conformity with the model's generative property by choosing lower values of $\lambda_1$ and relatively higher values of $\lambda_2$. Figure~\ref{fig:eccco} illustrates this balancing act for an example involving synthetic data: vector fields indicate the direction of gradients with respect to the different components of our proposed objective function (Equation~\ref{eq:eccco}).
-The entire procedure for Generating ECCCos is described in Algorithm~\ref{alg:eccco}. For the sake of simplicity and without loss of generality, we limit our attention to generating a single counterfactual $\mathbf{x}^\prime=f(\mathbf{z}^\prime)$ where in contrast to Equation~\ref{eq:eccco} $\mathbf{z}^\prime$ denotes a $1$-dimensional array containing a single counterfactual state. That state is initialized by passing the factual $\mathbf{x}$ through the encoder $f^{-1}$ which in our case corresponds to a simple feature transformer, rather than the encoder part of VAE as in REVISE~\citep{joshi2019realistic}.
-
 \medskip
 
 \renewcommand{\algorithmicrequire}{\textbf{Input:}}
@@ -254,14 +252,14 @@ The entire procedure for Generating ECCCos is described in Algorithm~\ref{alg:ec
 \begin{minipage}[c]{0.50\textwidth}
 \captionof{algorithm}{Generating ECCCos (For more details, see Appendix~\ref{app:eccco})}\label{alg:eccco}
 \begin{algorithmic}[1]
-    \Require $\mathbf{x}, \mathbf{y}^*, M_{\theta}, f, \Lambda, \alpha, \mathcal{D}, T, \eta, m, M$ \linebreak where $M_{\theta}(\mathbf{x})\neq\mathbf{y}^*$
+    \Require $\mathbf{x}, \mathbf{y}^*, M_{\theta}, f, \Lambda, \alpha, \mathcal{D}, T, \eta, n_{\mathcal{B}}, N_{\mathcal{B}}$ \linebreak where $M_{\theta}(\mathbf{x})\neq\mathbf{y}^*$
    \Ensure $\mathbf{x}^\prime$
    \State Initialize $\mathbf{z}^\prime \gets f^{-1}(\mathbf{x})$
-    \State Generate buffer $\mathcal{B}$ of $M$ conditional samples $\hat{\mathbf{x}}_{\theta}|\mathbf{y}^*$ using SGLD (Equation~\ref{eq:sgld})
+    \State Generate buffer $\mathcal{B}$ of $N_{\mathcal{B}}$ conditional samples $\hat{\mathbf{x}}_{\theta}|\mathbf{y}^*$ using SGLD (Equation~\ref{eq:sgld})
    \State Run \textit{SCP} for $M_{\theta}$ using $\mathcal{D}$
    \State Initialize $t \gets 0$
    \While{\textit{not converged} or $t < T$}
-        \State $\hat{\mathbf{x}}_{\theta, t} \gets \text{rand}(\mathcal{B},m)$
+        \State $\hat{\mathbf{x}}_{\theta, t} \gets \text{rand}(\mathcal{B},n_{\mathcal{B}})$
        \State $\mathbf{z}^\prime \gets \mathbf{z}^\prime - \eta \nabla_{\mathbf{z}^\prime} \mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^*,\hat{\mathbf{x}}_{\theta, t})$
        \State $t \gets t+1$
    \EndWhile
@@ -271,7 +269,20 @@ The entire procedure for Generating ECCCos is described in Algorithm~\ref{alg:ec
 \medskip
 
-\section{Evaluation Framework}\label{conformity}
+\begin{minipage}[c]{\textwidth}
+    \includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png}
+    \captionof{figure}{Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes, from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:eccco-mnist}
+\end{minipage}
+
+\medskip
+
+The entire procedure for generating ECCCos is described in Algorithm~\ref{alg:eccco}. For the sake of simplicity and without loss of generality, we limit our attention to generating a single counterfactual $\mathbf{x}^\prime=f(\mathbf{z}^\prime)$ where, in contrast to Equation~\ref{eq:eccco}, $\mathbf{z}^\prime$ denotes a $1$-dimensional array containing a single counterfactual state. That state is initialized by passing the factual $\mathbf{x}$ through the encoder $f^{-1}$, which in our case corresponds to a simple feature transformer rather than the encoder part of a VAE as in REVISE~\citep{joshi2019realistic}. Next, we generate a buffer $\mathcal{B}$ of $N_{\mathcal{B}}$ conditional samples $\hat{\mathbf{x}}_{\theta}|\mathbf{y}^*$ using SGLD (Equation~\ref{eq:sgld}) and conformalise the model $M_{\theta}$ through Split Conformal Prediction on training data $\mathcal{D}$.
+
+Finally, we search for counterfactuals through gradient descent. Let $\mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^*,\hat{\mathbf{x}}_{\theta, t})$ denote our loss function defined in Equation~\ref{eq:eccco}. In each iteration, we first randomly draw $n_{\mathcal{B}}$ samples from the buffer $\mathcal{B}$ and then update the counterfactual state $\mathbf{z}^\prime$ by taking a step in the direction of the negative gradient of that loss function. The search terminates once the convergence criterion is met or the maximum number of iterations $T$ has been exhausted. Note that the choice of convergence criterion has important implications for the final counterfactual (for more detail, see Appendix~\ref{app:eccco}).
+
+Figure~\ref{fig:eccco-mnist} presents ECCCos for the MNIST example from Section~\ref{background} for various Black Box models of increasing complexity, from left to right: a simple Multi-Layer Perceptron (MLP); an Ensemble of MLPs, each of the same architecture as the single MLP; a Joint Energy Model (JEM) based on the same MLP architecture; and finally, an Ensemble of these JEMs. Since Deep Ensembles have an improved capacity for predictive uncertainty quantification and JEMs are explicitly trained to learn plausible representations of the input data, it is intuitive that the plausibility of the counterfactuals visibly improves from left to right.
+
+\section{Experiments}\label{conformity}
 
 \subsection{Evaluation Measures}\label{evaluation}
@@ -297,10 +308,7 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar
 \section{Experiments}
 
-\begin{figure}
-    \includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png}
-    \caption{ECCCos from Black Boxes. Counterfactuals for turning an 8 (eight) into a 3 (three): original image (left); }\label{fig:eccco}
-\end{figure}
+
 
 \begin{itemize}
 	\item BatchNorm does not seem compatible with JEM
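To complement the new prose walkthrough of Algorithm~\ref{alg:eccco} added above, a schematic Julia sketch of the same control flow may help. Every component here is a stand-in rather than the paper's implementation: the encoder $f^{-1}$ is the identity, the buffer is pre-generated instead of sampled via SGLD, conformalisation (SCP) is omitted, and a toy quadratic gradient replaces the full ECCCo objective.

```julia
# Schematic sketch of Algorithm 1 (generating a single ECCCo counterfactual).
# All components are stand-ins: identity encoder, pre-generated buffer in
# place of SGLD sampling, toy quadratic gradient instead of the full loss.
using Random

encode(x) = copy(x)             # f⁻¹: simple feature transformer (here: identity)
decode(z) = copy(z)             # f
∇loss(z, y, x̂) = 2 .* (z .- x̂)  # toy stand-in for ∇_z′ L(z′, y*, x̂)

function generate_eccco(x, y; buffer, T=100, η=0.1, n_B=8, tol=1e-4)
    z = encode(x)                                   # z′ ← f⁻¹(x)
    for t in 1:T
        # Randomly draw n_B samples from the buffer, cf. rand(B, n_B):
        x̂ = buffer[:, rand(1:size(buffer, 2), n_B)]
        x̂_avg = vec(sum(x̂, dims=2)) ./ n_B          # average the draws
        g = ∇loss(z, y, x̂_avg)
        sqrt(sum(abs2, g)) < tol && break           # toy convergence criterion
        z .-= η .* g                                # z′ ← z′ − η ∇L
    end
    return decode(z)                                # x′ ← f(z′)
end

Random.seed!(2023)
buffer = 0.5 .+ 0.1 .* randn(2, 100)  # stand-in for SGLD samples x̂_θ | y*
x′ = generate_eccco([-1.0, -1.0], 1; buffer, T=200, η=0.05)
```

The real generator additionally conformalises $M_{\theta}$ via SCP before the loop and differentiates the full multi-penalty loss; this sketch only mirrors the initialise/sample/step/decode structure of Algorithm~\ref{alg:eccco}.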