Commit 98ab8da5 authored by pat-alt

methodology rewritten
Our default choice for the $\text{dist}(\cdot)$ function in both cases is the Euclidean norm.
Given our proposed notion of faithfulness, we now describe \textit{ECCCo}, our proposed framework for generating Energy-Constrained Conformal Counterfactuals. It is based on the premise that counterfactuals should first and foremost be faithful. Plausibility, as a secondary concern, is then still attainable, but only to the degree that the black-box model itself has learned plausible explanations for the underlying data.
We begin by substituting the loss function in Equation~\ref{eq:general},
\begin{equation} \label{eq:eccco-start}
\begin{aligned}
\mathbf{Z}^\prime =& \arg \min_{\mathbf{Z}^\prime \in \mathcal{Z}^L} \{ {L_{\text{JEM}}(f(\mathbf{Z}^\prime);M_{\theta},\mathbf{y}^+)}+ \lambda {\text{cost}(f(\mathbf{Z}^\prime)) } \}
\end{aligned}
\end{equation}
where $L_{\text{JEM}}(f(\mathbf{Z}^\prime);M_{\theta},\mathbf{y}^+)$ is a hybrid loss function used in joint-energy modelling evaluated at a given counterfactual state for a given model and target outcome:
\begin{equation}
\begin{aligned}
L_{\text{JEM}}(f(\mathbf{Z}^\prime); \cdot) = L_{\text{clf}}(f(\mathbf{Z}^\prime); \cdot) + L_{\text{gen}}(f(\mathbf{Z}^\prime); \cdot)
\end{aligned}
\end{equation}
The first term, $L_{\text{clf}}$, is any standard classification loss function such as cross-entropy loss. The second term, $L_{\text{gen}}$, is used to measure loss with respect to the generative task\footnote{In practice, regularization loss is typically also added. We follow this convention but have omitted the term here for simplicity.}. In the context of joint-energy training, $L_{\text{gen}}$ induces changes in model parameters $\theta$ that decrease the energy of observed samples and increase the energy of samples generated through SGLD~\citep{du2019implicit}.
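To make the two components concrete, the hybrid loss can be sketched in a few lines, assuming a plain linear classifier and the joint-energy-modelling convention that the conditional energy of a sample is the negative logit of its class; the function names and the toy model are illustrative assumptions, not the reference implementation, and the regularisation term is omitted as in the text:

```python
import numpy as np

def energy(theta, x, y):
    """Conditional energy: negative logit of class y under a linear model
    (joint-energy-modelling convention; illustrative only)."""
    logits = x @ theta
    return -logits[y]

def l_clf(theta, x, y):
    """Standard cross-entropy classification loss."""
    logits = x @ theta
    logits = logits - logits.max()                 # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())
    return -log_probs[y]

def l_jem(theta, x_obs, y_obs, x_sgld):
    """Hybrid JEM loss: classification term plus a generative term that
    lowers the energy of the observed sample relative to an SGLD sample."""
    l_gen = energy(theta, x_obs, y_obs) - energy(theta, x_sgld, y_obs)
    return l_clf(theta, x_obs, y_obs) + l_gen
```

During joint-energy training, minimising `l_gen` pushes model parameters to assign low energy to observed data and high energy to SGLD-generated data, which is what makes the classifier's parameters carry generative information.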
The key observation in our context is that it suffices to decrease the energy of the counterfactual itself: the generative property of the underlying model is implicitly encoded in its parameters $\theta$, so constraining the counterfactual's energy already aligns it with that property. Importantly, this means that we do not need to generate conditional samples through SGLD during the counterfactual search at all, as we explain in Appendix~\ref{app:eccco}.
This observation leads to the following simple objective function for \textit{ECCCo}:
\begin{equation} \label{eq:eccco}
\begin{aligned}
\mathbf{Z}^\prime =& \arg \min_{\mathbf{Z}^\prime \in \mathcal{Z}^L} \{ {L_{\text{clf}}(f(\mathbf{Z}^\prime);M_{\theta},\mathbf{y}^+)}+ \lambda_1 {\text{cost}(f(\mathbf{Z}^\prime)) } \\
&+ \lambda_2 \mathcal{E}_{\theta}(f(\mathbf{Z}^\prime)|\mathbf{y}^+) + \lambda_3 \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \}
\end{aligned}
\end{equation}
The first penalty term involving $\lambda_1$ induces closeness, as in~\citet{wachter2017counterfactual}. The second penalty term involving $\lambda_2$ induces faithfulness by constraining the energy of the generated counterfactual. The third and final penalty term involving $\lambda_3$ ensures that the generated counterfactual is associated with low predictive uncertainty.
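As a rough illustration of how the objective in Equation~\ref{eq:eccco} might be optimised, the sketch below assumes an identity encoder (so the search runs directly in feature space), a linear classifier whose conditional energy is the negative target logit, an L1 cost, and a smooth surrogate standing in for the conformal set-size penalty $\Omega$; the helper names, penalty weights, and the finite-difference optimiser are assumptions for illustration only:

```python
import numpy as np

def softmax(z):
    z = z - z.max()                             # numerical stability
    e = np.exp(z)
    return e / e.sum()

def eccco_objective(x_cf, x, y_target, theta, lambdas=(0.1, 0.1, 0.1)):
    """Penalised ECCCo-style objective (illustrative): target-class loss,
    L1 distance cost, conditional energy, and a smooth surrogate for the
    conformal set-size penalty Omega."""
    l1, l2, l3 = lambdas
    logits = x_cf @ theta
    probs = softmax(logits)
    yloss = -np.log(probs[y_target])            # cross-entropy toward target
    cost = np.abs(x_cf - x).sum()               # L1 proximity cost
    energy = -logits[y_target]                  # conditional energy of x_cf
    omega = 1.0 - probs[y_target]               # smooth uncertainty surrogate
    return yloss + l1 * cost + l2 * energy + l3 * omega

def descend(x, y_target, theta, steps=200, lr=0.05, eps=1e-5):
    """Plain finite-difference gradient descent on the objective."""
    x_cf = x.copy()
    for _ in range(steps):
        grad = np.zeros_like(x_cf)
        for i in range(x_cf.size):
            d = np.zeros_like(x_cf)
            d[i] = eps
            grad[i] = (eccco_objective(x_cf + d, x, y_target, theta)
                       - eccco_objective(x_cf - d, x, y_target, theta)) / (2 * eps)
        x_cf -= lr * grad
    return x_cf
```

On a toy two-class linear model, `descend` flips the predicted class toward the target while the L1 cost keeps the counterfactual close to the factual and the energy term steers it toward a low-energy region of the model.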
\begin{figure}
\centering