where $\mathbf{x}^{\prime}$ denotes the counterfactual and $\mathbf{X}_{\mathbf{y}^+}$ is a subsample of the training data in the target class $\mathbf{y}^+$. By averaging over multiple samples in this manner, we avoid the risk that the nearest neighbour of $\mathbf{x}^{\prime}$ is itself not plausible according to Definition~\ref{def:plausible} (e.g., an outlier).
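To make this concrete, the short Python sketch below computes such a distance-based metric, under the assumption that it reduces to the average distance between the counterfactual and the instances in the reference sample; the function and variable names are illustrative and not part of our implementation. The same computation with a different reference set yields the unfaithfulness metric introduced next.
\begin{verbatim}
import numpy as np

def avg_distance(x_prime: np.ndarray, X_ref: np.ndarray, ord: int = 1) -> float:
    """Average distance between a counterfactual x_prime and a reference set X_ref.

    With X_ref a subsample of observed instances in the target class, this acts
    as an implausibility proxy; with X_ref a set of model-generated samples, it
    acts as an unfaithfulness proxy.
    """
    dists = np.linalg.norm(X_ref - x_prime, ord=ord, axis=1)  # distance to each reference instance
    return float(dists.mean())                                # average over the reference sample
\end{verbatim}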
Equation~\ref{eq:impl} gives rise to a similar evaluation metric for unfaithfulness. We swap out the subsample of observed individuals in the target class for the set of samples generated through SGLD ($\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}$):
The first penalty term involving $\lambda_1$ induces proximity as in~\citet{wachter2017counterfactual}. Our default choice for $\text{dist}(\cdot)$ is the L1 norm due to its sparsity-inducing properties. The second penalty term involving $\lambda_2$ induces faithfulness by constraining the energy of the generated counterfactual, where we have:
\begin{equation}\label{eq:energy-delta}
\begin{aligned}
\Delta\mathcal{E}&=\mathcal{E}(f(\mathbf{Z}^\prime)|\mathbf{y}^+)-\mathcal{E}(\hat{\mathbf{x}}|\mathbf{y}^+), &&& \hat{\mathbf{x}} \sim\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}
\end{gr_replace_placeholder}
\end{aligned}
\end{equation}
In particular, this penalty ensures that the energy of the generated counterfactual is in balance with the energy of the generated conditional samples ($\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}$). The third and final penalty term involving $\lambda_3$ ensures that the generated counterfactual is associated with low predictive uncertainty.
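As a rough illustration, the sketch below computes $\Delta\mathcal{E}$ for a classifier, assuming that the class-conditional energy can be taken as the negative logit of the target class (the convention used for joint energy models) and that the differential is averaged over the generated samples; the interface and names are placeholders rather than our actual implementation.
\begin{verbatim}
import torch

def conditional_energy(model: torch.nn.Module, x: torch.Tensor, y_target: int) -> torch.Tensor:
    """Class-conditional energy, here taken as the negative logit of the target class."""
    logits = model(x)            # shape: (batch, n_classes)
    return -logits[:, y_target]  # shape: (batch,)

def energy_differential(model, x_prime, X_hat, y_target):
    """Energy of the counterfactual relative to the generated samples X_hat."""
    e_cf = conditional_energy(model, x_prime, y_target)        # energy of f(z')
    e_gen = conditional_energy(model, X_hat, y_target).mean()  # average energy of generated samples
    return (e_cf - e_gen).mean()
\end{verbatim}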
\begin{figure}
\centering
...
...
\Ensure$\mathbf{x}^\prime$
\State Initialize $\mathbf{z}^\prime\gets f^{-1}(\mathbf{x})$\Comment{Map to counterfactual state space.}
\State Generate $\left\{\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}\right\}_{n_{\mathcal{B}}}\gets p_{\theta}(\mathbf{x}_{\mathbf{y}^+})$\Comment{Generate $n_{\mathcal{B}}$ samples using SGLD (Equation~\ref{eq:sgld}).}
\State Store $\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}\gets\left\{\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}\right\}_{n_{\mathcal{B}}}$\Comment{Choose $n_E$ lowest-energy samples.}
\State Run \textit{SCP} for $M_{\theta}$ using $\mathcal{D}$\Comment{Calibrate model through split conformal prediction.}
\State Initialize $t \gets0$
\While{\textit{not converged} or $t < T$}\Comment{For convergence conditions see Appendix~\ref{app:eccco}.}
\State$\mathbf{z}^\prime\gets\mathbf{z}^\prime-\eta\nabla_{\mathbf{z}^\prime}\mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^+,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}; \Lambda, \alpha)$\Comment{Take gradient step of size $\eta$.}
\State$t \gets t+1$
\EndWhile
\State$\mathbf{x}^\prime\gets f(\mathbf{z}^\prime)$\Comment{Map back to feature space.}
...
...
While \textit{Wachter} generates a valid counterfactual, it ends up close to the original starting point, consistent with its objective. \textit{ECCCo (no EBM)} pushes the counterfactual further into the target domain to minimize predictive uncertainty, but the outcome is still not plausible. The counterfactual produced by \textit{ECCCo (no CP)} is attracted by the generated samples shown in bright yellow. Since the \textit{JEM} has learned the conditional input distribution reasonably well in this case, the counterfactuals are both faithful and plausible. Finally, the outcome for \textit{ECCCo} looks similar, but the additional smooth set-size penalty leads to somewhat faster convergence.
Algorithm~\ref{alg:eccco} describes exactly how \textit{ECCCo} works. For the sake of simplicity and without loss of generality, we limit our attention to generating a single counterfactual $\mathbf{x}^\prime=f(\mathbf{z}^\prime)$. The counterfactual state $\mathbf{z}^\prime$ is initialized by passing the factual $\mathbf{x}$ through a simple feature transformer $f^{-1}$. Next, we generate $n_{\mathcal{B}}$ conditional samples $\hat{\mathbf{x}}_{\theta,\mathbf{y}^+}$ using SGLD (Equation~\ref{eq:sgld}) and store the $n_E$ instances with the lowest energy. We then calibrate the model $M_{\theta}$ through split conformal prediction. Finally, we search for counterfactuals through gradient descent, where $\mathcal{L}(\mathbf{z}^\prime,\mathbf{y}^+,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}; \Lambda, \alpha)$ denotes our loss function defined in Equation~\ref{eq:eccco}. The search terminates once the convergence criterion is met or the maximum number of iterations $T$ has been exhausted. Note that the choice of convergence criterion has important implications for the final counterfactual, which we explain in Appendix~\ref{app:eccco}.
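For illustration, the Python sketch below mirrors the two main stages of Algorithm~\ref{alg:eccco}: drawing conditional samples via SGLD and running the gradient-based counterfactual search. It omits the conformal calibration step and the convergence checks, uses a plain SGLD variant (step-size conventions vary), and all names are placeholders rather than our actual implementation.
\begin{verbatim}
import torch

def sgld_samples(energy_fn, n_samples: int, dim: int, n_steps: int = 100, step_size: float = 1e-2):
    """Draw approximate conditional samples by running SGLD on the class-conditional energy."""
    x = torch.randn(n_samples, dim, requires_grad=True)  # initialize from a standard normal
    for _ in range(n_steps):
        grad = torch.autograd.grad(energy_fn(x).sum(), x)[0]
        noise = torch.randn_like(x)
        x = (x - 0.5 * step_size * grad + step_size ** 0.5 * noise).detach().requires_grad_(True)
    return x.detach()

def counterfactual_search(loss_fn, x_factual, f, f_inv, n_iter: int = 100, lr: float = 1e-2):
    """Gradient-descent search in the counterfactual state space.

    loss_fn closes over the target class, the generated samples and the penalty
    weights, i.e. it evaluates the full objective at the current state.
    """
    z = f_inv(x_factual).clone().detach().requires_grad_(True)  # map to counterfactual state space
    optimizer = torch.optim.SGD([z], lr=lr)
    for _ in range(n_iter):
        optimizer.zero_grad()
        loss_fn(z).backward()  # evaluate the objective and accumulate gradients
        optimizer.step()       # z <- z - lr * grad
    return f(z.detach())       # map back to the feature space
\end{verbatim}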