Skip to content
Snippets Groups Projects
Commit e9f9590a authored by pat-alt's avatar pat-alt
Browse files

more work on the appendix

parent 42dabcdc
No related branches found
No related tags found
1 merge request!8373 aries comments
No preview for this file type
......@@ -6,25 +6,47 @@ The following appendices provide additional details that are relevant to the pap
\subsection{Energy-Based Modelling}\label{app:jem}
Since we were not able to identify any existing open-source software for Energy-Based Modelling that would be flexible enough to cater to our needs, we have developed a \texttt{Julia} package from scratch. The package has been open-sourced, but to avoid compromising the double-blind review process, we refrain from providing more information at this stage. In our development we have heavily drawn on the existing literature:~\citet{du2020implicit} describe best practices for using EBM for generative modelling;~\citet{grathwohl2020your} explain how EBM can be used to train classifiers jointly for the discriminative and generative tasks. We have used the same package for training and inference, but there are some important differences between the two cases that are worth highlighting here.
Since we were not able to identify any existing open-source software for Energy-Based Modelling that would be flexible enough to cater to our needs, we have developed a \texttt{Julia} package from scratch. The package has been open-sourced, but to avoid compromising the double-blind review process, we refrain from providing more information at this stage. In our development we have heavily drawn on the existing literature:~\citet{du2019implicit} describe best practices for using EBM for generative modelling;~\citet{grathwohl2020your} explain how EBM can be used to train classifiers jointly for the discriminative and generative tasks. We have used the same package for training and inference, but there are some important differences between the two cases that are worth highlighting here.
\subsubsection{Training: Joint Energy Models}
To train our Joint Energy Models we broadly follow the approach outlined in~\citet{grathwohl2020your}. These models are trained to optimize a hybrid objective that involves a standard classification loss component $L_{\text{clf}}(\theta)=-\log p_{\theta}(\mathbf{y}|\mathbf{x})$ (e.g. cross-entropy loss) as well as a generative loss component $L_{\text{gen}}(\theta)=-\log p_{\theta}(\mathbf{x})$.
To train our Joint Energy Models we broadly follow the approach outlined in~\citet{grathwohl2020your}. Formally, JEMs are defined by the following joint distribution:
To draw samples from $p_{\theta}(\mathbf{x})$, we rely exclusively on the conditional sampling approach described in~\citet{grathwohl2020your} for both training and inference: we first draw $\mathbf{y}\sim p(\mathbf{y})$ and then sample $\mathbf{x} \sim p_{\theta}(\mathbf{x}|\mathbf{y})$~\citep{grathwohl2020your} via Equation~\ref{eq:sgld} with energy $\mathcal{E}(\mathbf{x}|\mathbf{y})=\mu_{\theta}(\mathbf{x})[\mathbf{y}]$ where $\mu_{\theta}: \mathcal{X} \mapsto \mathbb{R}^K$ returns the linear predictions (logits) of our classifier $M_{\theta}$. While our package also supports unconditional sampling, we found conditional sampling to work well. It is also well aligned with CE, since in this context we are interested in conditioning on the target class.
\begin{equation}
\begin{aligned}
\log p\_{\theta}(\mathbf{x},\mathbf{y}) &= \log p_{\theta}(\mathbf{y}|\mathbf{x}) + \log p_{\theta}(\mathbf{x})
\end{aligned}
\end{equation}
Training therefore involves a standard classification loss component $L_{\text{clf}}(\theta)=-\log p_{\theta}(\mathbf{y}|\mathbf{x})$ (e.g. cross-entropy loss) as well as a generative loss component $L_{\text{gen}}(\theta)=-\log p_{\theta}(\mathbf{x})$. Analogous to how we defined the conditional distribution over inputs in Definition~\ref{def:faithful}, $p_{\theta}(\mathbf{x})$ denotes the unconditional distribution over inputs. The model gradient of this component of the loss function can be expressed as follows:
\begin{equation}\label{eq:gen-true}
\begin{aligned}
\nabla_{\theta}L_{\text{gen}}(\theta)&=-\nabla_{\theta}\log p\_{\theta}(\mathbf{x})=-\left(\mathbb{E}_{p(\mathbf{x})} \left\{ \nabla_{\theta} \mathcal{E}_{\theta}(\mathbf{x}) \right\} - \mathbb{E}_{p_{\theta}(\mathbf{x})} \left\{ \nabla_{\theta} \mathcal{E}_{\theta}(\mathbf{x}) \right\} \right)
\end{aligned}
\end{equation}
To draw samples from $p_{\theta}(\mathbf{x})$, we rely exclusively on the conditional sampling approach described in~\citet{grathwohl2020your} for both training and inference: we first draw $\mathbf{y}\sim p(\mathbf{y})$ and then sample $\mathbf{x} \sim p_{\theta}(\mathbf{x}|\mathbf{y})$~\citep{grathwohl2020your} via Equation~\ref{eq:sgld} with energy $\mathcal{E}_{\theta}(\mathbf{x}|\mathbf{y})=\mu_{\theta}(\mathbf{x})[\mathbf{y}]$ where $\mu_{\theta}: \mathcal{X} \mapsto \mathbb{R}^K$ returns the linear predictions (logits) of our classifier $M_{\theta}$. While our package also supports unconditional sampling, we found conditional sampling to work well. It is also well aligned with CE, since in this context we are interested in conditioning on the target class.
As mentioned in the body of the paper, we rely on a biased sampler involving separately specified values for the step size $\epsilon$ and the standard deviation $\sigma$ of the stochastic term involving $\mathbf{r}$. Formally, our biased sampler performs updates as follows:
\begin{equation}\label{eq:biased-sgld}
\begin{aligned}
\hat{\mathbf{x}}_{j+1} &\leftarrow \hat{\mathbf{x}}_j - \frac{\epsilon}{2} \mathcal{E}(\hat{\mathbf{x}}_j|\mathbf{y}^+) + \sigma \mathbf{r}_j, && j=1,...,J
\hat{\mathbf{x}}_{j+1} &\leftarrow \hat{\mathbf{x}}_j - \frac{\epsilon}{2} \mathcal{E}_{\theta}(\hat{\mathbf{x}}_j|\mathbf{y}^+) + \sigma \mathbf{r}_j, && j=1,...,J
\end{aligned}
\end{equation}
Consistent with~\citet{grathwohl2020your}, we have specified $\epsilon=2$ and $\sigma=0.01$ as the default values for all of our experiments. The number of total SGLD steps $J$ varies by dataset (Table~\ref{tab:ebmparams}). Following best practices, we initialize $\mathbf{x}_0$ randomly in 5\% of all cases and sample from a buffer in all other cases. The buffer itself is randomly initialised and gradually grows to a maximum of 10,000 samples during training as $\hat{\mathbf{x}}_{J}$ is stored in each epoch~\citep{du2019implicit,grathwohl2020your}.
It is important to realise that sampling is done during each training epoch, which makes training Joint Energy Models significantly harder than conventional neural classifiers. In each epoch the generated (batch of) sample(s) $\hat{\mathbf{x}}_{J}$ is used as part of the generative loss component, which compares its energy to that of observed samples $\mathbf{x}$: $L_{\text{gen}}(\theta)=\mu_{\theta}(\mathbf{x})[\mathbf{y}]-\mu_{\theta}(\hat{\mathbf{x}}_{J})[\mathbf{y}]$. Our full training objective can be summarized as follows,
It is important to realise that sampling is done during each training epoch, which makes training Joint Energy Models significantly harder than conventional neural classifiers. In each epoch the generated (batch of) sample(s) $\hat{\mathbf{x}}_{J}$ is used as part of the generative loss component, which compares its energy to that of observed samples $\mathbf{x}$:
\begin{equation}\label{eq:gen-loss}
\begin{aligned}
L_{\text{gen}}(\theta)&\approx\mu_{\theta}(\mathbf{x})[\mathbf{y}]-\mu_{\theta}(\hat{\mathbf{x}}_{J})[\mathbf{y}]
\end{aligned}
\end{equation}
Our full training objective can be summarized as follows,
\begin{equation}\label{eq:jem-loss}
\begin{aligned}
......@@ -32,7 +54,7 @@ It is important to realise that sampling is done during each training epoch, whi
\end{aligned}
\end{equation}
where $L_{\text{reg}}(\theta)$ is a Ridge penalty (L2 norm) that regularises energy magnitudes for both observed and generated samples~\citep{du2020implicit}. We have used varying degrees of regularization depending on the dataset ($\lambda$ in Table~\ref{tab:ebmparams}).
where $L_{\text{reg}}(\theta)$ is a Ridge penalty (L2 norm) that regularises energy magnitudes for both observed and generated samples~\citep{du2019implicit}. We have used varying degrees of regularization depending on the dataset ($\lambda$ in Table~\ref{tab:ebmparams}).
Contrary to existing work, we have not typically used the entire minibatch of training data for the generative loss component but found that using a subset of the minibatch was often sufficient in attaining decent generative performance (Table~\ref{tab:ebmparams}). This has helped to reduce the computational burden for our models, which should make it easier for others to reproduce our findings. Figures~\ref{fig:mnist-gen} and~\ref{fig:moons-gen} show generated samples for our \textit{MNIST} and \textit{Moons} data, to provide a sense of their generative property.
......@@ -98,7 +120,7 @@ The counterfactual search objective for \textit{ECCCo} was introduced in Equatio
\begin{equation} \label{eq:eccco-app}
\begin{aligned}
\mathbf{Z}^\prime &= \arg \min_{\mathbf{Z}^\prime \in \mathcal{Z}^L} \{ {\text{yloss}(M_{\theta}(f(\mathbf{Z}^\prime)),\mathbf{y}^+)}+ \lambda_{1} {\text{dist}(f(\mathbf{Z}^\prime),\mathbf{x}) } \\
&+ \lambda_2 \Delta\mathcal{E}(\mathbf{Z}^\prime,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}) + \lambda_3 \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \}
&+ \lambda_2 \Delta\mathcal{E}_{\theta}(\mathbf{Z}^\prime,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}) + \lambda_3 \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \}
\end{aligned}
\end{equation}
......@@ -110,18 +132,25 @@ We can make the connection to energy-based modeling more explicit by restating t
\end{aligned}
\end{equation}
since $\Delta\mathcal{E}(\cdot)$ is equivalent to the generative loss function $L_{\text{gen}}(\cdot)$. In fact, this is also true for $\lambda L_{\text{reg}}(\theta)\ne0$ since we use the Ridge penalty $L_{\text{reg}}(\theta)$ in the counterfactual search just like we do in joint-energy training. This detail was omitted from the body of the paper for the sake of simplicity.
since $\Delta\mathcal{E}_{\theta}(\cdot)$ is equivalent to the generative loss function $L_{\text{gen}}(\cdot)$. In fact, this is also true for $\lambda L_{\text{reg}}(\theta)\ne0$ since we use the Ridge penalty $L_{\text{reg}}(\theta)$ in the counterfactual search just like we do in joint-energy training. This detail was omitted from the body of the paper for the sake of simplicity.
Aside from the additional penalties in Equation~\ref{eq:eccco-app}, the only key difference between our counterfactual search objective and the joint-energy training objective is the parameter that is being optimized. In joint-energy training we optimize the objective with respect to the network weights $\theta$:
Aside from the additional penalties in Equation~\ref{eq:eccco-app}, the only key difference between our counterfactual search objective and the joint-energy training objective is the parameter that is being optimized. In joint-energy training we optimize the objective with respect to the network weights $\theta$. Recall that $\mathcal{E}_{\theta}(\mathbf{x}|\mathbf{y})=\mu_{\theta}(\mathbf{x})[\mathbf{y}]$. Then the partial gradient with respect to the generative loss component can be expressed as follows:
\begin{equation}\label{eq:jem-grad}
\begin{aligned}
\nabla_{\theta} L_{\text{JEM}}(\theta) &=\nabla_{\theta} L_{\text{clf}}(\theta) + \nabla_{\theta}L_{\text{gen}}(\theta) + \lambda \nabla_{\theta} L_{\text{reg}}(\theta)
\nabla_{\theta}L_{\text{gen}}(\theta) &= \nabla_{\theta}\mu_{\theta}(\mathbf{x})[\mathbf{y}]- \nabla_{\theta}\mu_{\theta}(\hat{\mathbf{x}}_{J})[\mathbf{y}]
\end{aligned}
\end{equation}
During the counterfactual search we take the network parameters as fixed and instead optimize with respect to counterfactual itself:
During the counterfactual search, we take the network parameters as fixed and instead optimize with respect to the counterfactual itself,
\begin{equation}\label{eq:ce-grad}
\begin{aligned}
\nabla_{\mathbf{x}}L_{\text{gen}}(\theta) &= \nabla_{\mathbf{x}}\mu_{\theta}(\mathbf{x})[\mathbf{y}]- \nabla_{\mathbf{x}}\mu_{\theta}(\hat{\mathbf{x}}_{J})[\mathbf{y}]
\end{aligned}
\end{equation}
where we omit the notion of a latent search space to make the comparison easier. Intuitively, taking iterative gradient steps according to Equation~\ref{eq:ce-grad} has the effect of decreasing the energy of the counterfactual until it is in balance with the energy of conditional samples generated through SGLD.
\subsubsection{A Note on Convergence}
......
......@@ -87,11 +87,11 @@ To assess counterfactuals with respect to Definition~\ref{def:faithful}, we need
\begin{equation}\label{eq:sgld}
\begin{aligned}
\mathbf{x}_{j+1} &\leftarrow \mathbf{x}_j - \frac{\epsilon_j^2}{2} \mathcal{E}(\mathbf{x}_j|\mathbf{y}^+) + \epsilon_j \mathbf{r}_j, && j=1,...,J
\mathbf{x}_{j+1} &\leftarrow \mathbf{x}_j - \frac{\epsilon_j^2}{2} \mathcal{E}_{\theta}(\mathbf{x}_j|\mathbf{y}^+) + \epsilon_j \mathbf{r}_j, && j=1,...,J
\end{aligned}
\end{equation}
where $\mathbf{r}_j \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ is the stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed~\citep{welling2011bayesian}. The term $\mathcal{E}(\mathbf{x}_j|\mathbf{y}^+)$ denotes the model energy conditioned on the target class label $\mathbf{y}^+$ which we specify as the negative logit corresponding to the target class label $\mathbf{y}^{+}$. To allow for faster sampling, we follow the common practice of choosing the step-size $\epsilon_j$ and the standard deviation of $\mathbf{r}_j$ separately. While $\mathbf{x}_J$ is only guaranteed to distribute as $p_{\theta}(\mathbf{x}|\mathbf{y}^{+})$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice \citep{murphy2023probabilistic,grathwohl2020your}. Appendix~\ref{app:jem} provides additional implementation details for any tasks related to energy-based modelling.
where $\mathbf{r}_j \sim \mathcal{N}(\mathbf{0},\mathbf{I})$ is the stochastic term and the step-size $\epsilon_j$ is typically polynomially decayed~\citep{welling2011bayesian}. The term $\mathcal{E}_{\theta}(\mathbf{x}_j|\mathbf{y}^+)$ denotes the model energy conditioned on the target class label $\mathbf{y}^+$ which we specify as the negative logit corresponding to the target class label $\mathbf{y}^{+}$. To allow for faster sampling, we follow the common practice of choosing the step-size $\epsilon_j$ and the standard deviation of $\mathbf{r}_j$ separately. While $\mathbf{x}_J$ is only guaranteed to distribute as $p_{\theta}(\mathbf{x}|\mathbf{y}^{+})$ if $\epsilon \rightarrow 0$ and $J \rightarrow \infty$, the bias introduced for a small finite $\epsilon$ is negligible in practice \citep{murphy2023probabilistic,grathwohl2020your}. Appendix~\ref{app:jem} provides additional implementation details for any tasks related to energy-based modelling.
Generating multiple samples using SGLD thus yields an empirical distribution $\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}$ that approximates what the model has learned about the input data. While in the context of EBM, this is usually done during training, we propose to repurpose this approach during inference in order to evaluate and generate faithful model explanations.
......@@ -153,7 +153,7 @@ We begin by stating our proposed objective function, which involves tailored los
\begin{equation} \label{eq:eccco}
\begin{aligned}
\mathbf{Z}^\prime &= \arg \min_{\mathbf{Z}^\prime \in \mathcal{Z}^L} \{ {\text{yloss}(M_{\theta}(f(\mathbf{Z}^\prime)),\mathbf{y}^+)}+ \lambda_{1} {\text{dist}(f(\mathbf{Z}^\prime),\mathbf{x}) } \\
&+ \lambda_2 \Delta\mathcal{E}(\mathbf{Z}^\prime,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}) + \lambda_3 \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \}
&+ \lambda_2 \Delta\mathcal{E}_{\theta}(\mathbf{Z}^\prime,\widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}) + \lambda_3 \Omega(C_{\theta}(f(\mathbf{Z}^\prime);\alpha)) \}
\end{aligned}
\end{equation}
......@@ -161,7 +161,7 @@ The first penalty term involving $\lambda_1$ induces proximity like in~\citet{wa
\begin{equation} \label{eq:energy-delta}
\begin{aligned}
\Delta\mathcal{E}&=\mathcal{E}(f(\mathbf{Z}^\prime)|\mathbf{y}^+)-\mathcal{E}(x|\mathbf{y}^+) &&& x \sim \widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}
\Delta\mathcal{E}_{\theta}&=\mathcal{E}_{\theta}(f(\mathbf{Z}^\prime)|\mathbf{y}^+)-\mathcal{E}_{\theta}(x|\mathbf{y}^+) &&& x \sim \widehat{\mathbf{X}}_{\theta,\mathbf{y}^+}
\end{aligned}
\end{equation}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment