Commit b8cc35eb authored by pat-alt

added energy delta

parent 715bbf64
1 merge request !8373 aries comments
@@ -2,10 +2,11 @@
1. Applied to additional commonly used tabular real-world datasets
2. Energy delta
3. Better results also for image data
4. Counterfactual explanations do not scale well to high-dimensional input data
1. Better results, in particular for image data
2. No longer biased (addressing reviewer concern)
3. Counterfactual explanations do not scale well to high-dimensional input data
1. We have added native support for multi-processing and multi-threading
2. We have run more extensive experiments including fine-tuning hyperparameter choices
5. We have revisited the mathematical notation.
6. We have moved the introduction of conformal prediction forward and added more detail in line with reviewer feedback.
7. We have extended the limitations section.
\ No newline at end of file
4. We have revisited the mathematical notation.
5. We have moved the introduction of conformal prediction forward and added more detail in line with reviewer feedback.
6. We have extended the limitations section.
\ No newline at end of file
@@ -134,15 +134,15 @@ The parallels between our definitions of plausibility and faithfulness imply that
where $\mathbf{x}^{\prime}$ denotes the counterfactual and $\mathbf{X}_{\mathbf{y}^+}$ is a subsample of the training data in the target class $\mathbf{y}^+$. By averaging over multiple samples in this manner, we avoid the risk that the nearest neighbour of $\mathbf{x}^{\prime}$ is itself not plausible according to Definition~\ref{def:plausible} (e.g., an outlier).
Equation~\ref{eq:impl} gives rise to a similar evaluation metric for unfaithfulness. We merely swap out the subsample of individuals in the target class for a subset $\widehat{\mathbf{X}}_{\mathbf{y}^+}$ of the generated conditional samples:
Equation~\ref{eq:impl} gives rise to a similar evaluation metric for unfaithfulness. We swap out the subsample of observed individuals in the target class for the set of samples generated through SGLD ($\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}$):
\begin{equation}\label{eq:faith}
\begin{aligned}
\text{unfaith}(\mathbf{x}^{\prime},\widehat{\mathbf{X}}_{\mathbf{y}^+}) = \frac{1}{n_E} \sum_{\mathbf{x} \in \widehat{\mathbf{X}}_{\mathbf{y}^+}} \text{dist}(\mathbf{x}^{\prime},\mathbf{x})
\text{unfaith}(\mathbf{x}^{\prime},\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}) = \frac{1}{\lvert \widehat{\mathbf{X}}_{\theta,\mathbf{y}^+} \rvert} \sum_{\mathbf{x} \in \widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}} \text{dist}(\mathbf{x}^{\prime},\mathbf{x})
\end{aligned}
\end{equation}
Specifically, we form this subset based on the $n_E$ generated samples with the lowest energy.
Our default choice for $\text{dist}(\cdot)$ in both cases is the L1 norm.
\section{Energy-Constrained Conformal Counterfactuals}\label{meth}
@@ -153,11 +153,19 @@ We begin by stating our proposed objective function, which involves tailored loss
\begin{equation} \label{eq:eccco}
\begin{aligned}
\mathbf{Z}^\prime &= \arg \min_{\mathbf{Z}^\prime \in \mathcal{Z}^L} \{ {\text{yloss}(M_{\theta}(f(\mathbf{Z}^\prime)),\mathbf{y}^+)}+ \lambda_{1} {\text{dist}(f(\mathbf{Z}^\prime),\mathbf{x}) } \\
&+ \lambda_2 \text{unfaith}(f(\mathbf{Z}^\prime),\widehat{\mathbf{X}}_{\mathbf{y}^+}) + \lambda_3 \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \}
&+ \lambda_2 \Delta\mathcal{E}(\mathbf{Z}^\prime,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}) + \lambda_3 \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \}
\end{aligned}
\end{equation}
The first penalty term involving $\lambda_1$ induces proximity as in~\citet{wachter2017counterfactual}. Our default choice for $\text{dist}(\cdot)$ is the L1 norm due to its sparsity-inducing properties. The second penalty term involving $\lambda_2$ induces faithfulness by constraining the energy of the generated counterfactual, where $\text{unfaith}(\cdot)$ corresponds to the metric defined in Equation~\ref{eq:faith}. The third and final penalty term involving $\lambda_3$ ensures that the generated counterfactual is associated with low predictive uncertainty.
The first penalty term involving $\lambda_1$ induces proximity as in~\citet{wachter2017counterfactual}. Our default choice for $\text{dist}(\cdot)$ is the L1 norm due to its sparsity-inducing properties. The second penalty term involving $\lambda_2$ induces faithfulness by constraining the energy of the generated counterfactual, where we have:
\begin{equation} \label{eq:energy-delta}
\begin{aligned}
\Delta\mathcal{E}&=\mathcal{E}(f(\mathbf{Z}^\prime)|\mathbf{y}^+)-\mathcal{E}(\mathbf{x}|\mathbf{y}^+), \quad \mathbf{x} \sim \widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}
\end{aligned}
\end{equation}
In particular, this penalty ensures that the energy of the generated counterfactual is in balance with the energy of the generated conditional samples ($\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}$). The third and final penalty term involving $\lambda_3$ ensures that the generated counterfactual is associated with low predictive uncertainty.
\begin{figure}
\centering
@@ -172,11 +180,11 @@ The first penalty term involving $\lambda_1$ induces proximity as in~\citet{wachter2017counterfactual}
\Ensure $\mathbf{x}^\prime$
\State Initialize $\mathbf{z}^\prime \gets f^{-1}(\mathbf{x})$ \Comment{Map to counterfactual state space.}
\State Generate $\left\{\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}\right\}_{n_{\mathcal{B}}} \gets p_{\theta}(\mathbf{x}_{\mathbf{y}^+})$ \Comment{Generate $n_{\mathcal{B}}$ samples using SGLD (Equation~\ref{eq:sgld}).}
\State Store $\widehat{\mathbf{X}}_{\mathbf{y}^+} \gets \left\{\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}\right\}_{n_{\mathcal{B}}}$ \Comment{Choose $n_E$ lowest-energy samples.}
\State Store $\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+} \gets \left\{\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}\right\}_{n_{\mathcal{B}}}$ \Comment{Choose $n_E$ lowest-energy samples.}
\State Run \textit{SCP} for $M_{\theta}$ using $\mathcal{D}$ \Comment{Calibrate model through split conformal prediction.}
\State Initialize $t \gets 0$
\While{\textit{not converged} and $t < T$} \Comment{For convergence conditions see Appendix~\ref{app:eccco}.}
\State $\mathbf{z}^\prime \gets \mathbf{z}^\prime - \eta \nabla_{\mathbf{z}^\prime} \mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^+,\widehat{\mathbf{X}}_{\mathbf{y}^+}; \Lambda, \alpha)$ \Comment{Take gradient step of size $\eta$.}
\State $\mathbf{z}^\prime \gets \mathbf{z}^\prime - \eta \nabla_{\mathbf{z}^\prime} \mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^+,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}; \Lambda, \alpha)$ \Comment{Take gradient step of size $\eta$.}
\State $t \gets t+1$
\EndWhile
\State $\mathbf{x}^\prime \gets f(\mathbf{z}^\prime)$ \Comment{Map back to feature space.}
@@ -187,7 +195,7 @@ Figure~\ref{fig:poc} illustrates how the different components in Equation~\ref{eq:eccco}
While \textit{Wachter} generates a valid counterfactual, it ends up close to the original starting point, consistent with its objective. \textit{ECCCo (no EBM)} pushes the counterfactual further into the target domain to minimize predictive uncertainty, but the outcome is still not plausible. The counterfactual produced by \textit{ECCCo (no CP)} is attracted by the generated samples shown in bright yellow. Since the \textit{JEM} has learned the conditional input distribution reasonably well in this case, the counterfactuals are both faithful and plausible. Finally, the outcome for \textit{ECCCo} looks similar, but the additional smooth set-size penalty leads to somewhat faster convergence.
Algorithm~\ref{alg:eccco} describes exactly how \textit{ECCCo} works. For the sake of simplicity and without loss of generality, we limit our attention to generating a single counterfactual $\mathbf{x}^\prime=f(\mathbf{z}^\prime)$. The counterfactual state $\mathbf{z}^\prime$ is initialized by passing the factual $\mathbf{x}$ through a simple feature transformer $f^{-1}$. Next, we generate $n_{\mathcal{B}}$ conditional samples $\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}$ using SGLD (Equation~\ref{eq:sgld}) and store the $n_E$ instances with the lowest energy. We then calibrate the model $M_{\theta}$ through split conformal prediction. Finally, we search counterfactuals through gradient descent, where $\mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^+,\widehat{\mathbf{X}}_{\mathbf{y}^+}; \Lambda, \alpha)$ denotes our loss function defined in Equation~\ref{eq:eccco}. The search terminates once the convergence criterion is met or the maximum number of iterations $T$ has been exhausted. Note that the choice of convergence criterion has important implications for the final counterfactual, which we explain in Appendix~\ref{app:eccco}.
Algorithm~\ref{alg:eccco} describes exactly how \textit{ECCCo} works. For the sake of simplicity and without loss of generality, we limit our attention to generating a single counterfactual $\mathbf{x}^\prime=f(\mathbf{z}^\prime)$. The counterfactual state $\mathbf{z}^\prime$ is initialized by passing the factual $\mathbf{x}$ through a simple feature transformer $f^{-1}$. Next, we generate $n_{\mathcal{B}}$ conditional samples $\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}$ using SGLD (Equation~\ref{eq:sgld}) and store the $n_E$ instances with the lowest energy. We then calibrate the model $M_{\theta}$ through split conformal prediction. Finally, we search counterfactuals through gradient descent, where $\mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^+,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}; \Lambda, \alpha)$ denotes our loss function defined in Equation~\ref{eq:eccco}. The search terminates once the convergence criterion is met or the maximum number of iterations $T$ has been exhausted. Note that the choice of convergence criterion has important implications for the final counterfactual, which we explain in Appendix~\ref{app:eccco}.
\section{Empirical Analysis}\label{emp}