---
title: High-Fidelity Counterfactual Explanations through Conformal Prediction
subtitle: Research Proposal
abstract: |
    We propose Conformal Counterfactual Explanations: an effortless and rigorous way to produce realistic and faithful Counterfactual Explanations using Conformal Prediction. To address the need for realistic counterfactuals, existing work has primarily relied on separate generative models to learn the data-generating process. While this is an effective way to produce plausible and model-agnostic counterfactual explanations, it not only introduces significant engineering overhead but also reallocates the task of creating realistic model explanations from the model itself to the generative model. Recent work has shown that no separate generative model is needed when working with probabilistic models that explicitly quantify their own uncertainty. Unfortunately, most models used in practice still do not fulfil that basic requirement, in which case we would like to have a way to quantify predictive uncertainty in a post-hoc fashion.
---

```{julia}
include("notebooks/setup.jl")
eval(setup_notebooks)
```

## Motivation

Counterfactual Explanations are a powerful, flexible and intuitive way to not only explain black-box models but also enable affected individuals to challenge them by means of Algorithmic Recourse.

### Counterfactual Explanations or Adversarial Examples?

Most state-of-the-art approaches to generating Counterfactual Explanations (CE) rely on gradient descent in the feature space. The key idea is to perturb inputs $x\in\mathcal{X}$ to a black-box model $f: \mathcal{X} \to \mathcal{Y}$ in order to change the model output $f(x)$ to some pre-specified target value $t\in\mathcal{Y}$. Formally, this boils down to defining some loss function $\ell(f(x),t)$ and taking gradient steps in the minimizing direction. Counterfactuals generated in this way are considered valid as soon as the predicted label matches the target label. A stripped-down counterfactual explanation is therefore little different from an adversarial example. In @fig-adv, for example, generic counterfactual search as in @wachter2017counterfactual has been applied to MNIST data.
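To make the mechanics concrete, the following is a minimal, self-contained sketch of gradient-based counterfactual search. It is not the implementation used for the figure: the logistic classifier, its weights and the helper names `grad_loss` and `counterfactual_search` are hypothetical stand-ins, and the small distance penalty weighted by `λ` is an assumption in the spirit of @wachter2017counterfactual.

```julia
using LinearAlgebra

# Hypothetical stand-in for the black-box model: a fixed logistic classifier
σ(z) = 1 / (1 + exp(-z))
w, b = [2.0, -1.0], 0.5
f(x) = σ(dot(w, x) + b)                      # p(y = 1 | x)

# Gradient of the binary cross-entropy loss with respect to the input
grad_loss(x_cf, t) = (f(x_cf) - t) .* w

function counterfactual_search(x, t; η=0.1, λ=0.01, γ=0.9, T=50)
    x_cf = copy(x)
    for _ in 1:T
        p_target = t == 1 ? f(x_cf) : 1 - f(x_cf)
        p_target ≥ γ && break                # valid once the threshold is reached
        # Gradient step on the loss plus an (assumed) L1 penalty on the distance to x
        x_cf .-= η .* (grad_loss(x_cf, t) .+ λ .* sign.(x_cf .- x))
    end
    return x_cf
end

x = [-1.0, 1.0]                              # factual, predicted as class 0
x_cf = counterfactual_search(x, 1)           # counterfactual for target class 1
```

Nothing in this search constrains `x_cf` to look like real data, which is why the resulting counterfactuals can resemble adversarial examples.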

```{julia}
# Data:
counterfactual_data = load_mnist()
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
input_dim, n_obs = size(counterfactual_data.X)
M = load_mnist_mlp()

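# Target:
factual_label = 8
# Draw a random sample that the model currently classifies as the factual label:
x = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = 3
factual = predict_label(M, counterfactual_data, x)[1]
γ = 0.9     # decision threshold passed to the counterfactual search
T = 50      # maximum number of search iterations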
```

```{julia}
# Search:
generator = GenericGenerator()
ce_wachter = generate_counterfactual(
    x, target, counterfactual_data, M, generator; 
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
generator = GreedyGenerator(η=5.0)
ce_jsma = generate_counterfactual(
    x, target, counterfactual_data, M, generator; 
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
```

```{julia}
p1 = plot(
    convert2image(MNIST, reshape(x,28,28)),
    axis=nothing, 
    size=(img_height, img_height),
    title="Factual"
)
plts = [p1]

ces = zip([ce_wachter,ce_jsma])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
    ce, _phat = (x[1],x[2])
    _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))"
    plt = plot(
        convert2image(MNIST, reshape(ce,28,28)),
        axis=nothing, 
        size=(img_height, img_height),
        title=_title
    )
    plts = [plts..., plt]
end
plt = plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
savefig(plt, joinpath(www_path, "you_may_not_like_it.png"))
```

![You may not like it, but this is what stripped-down counterfactuals look like. Here we have used the generic generator of @wachter2017counterfactual and a greedy generator to produce counterfactuals for turning an 8 (eight) into a 3 (three).](www/you_may_not_like_it.png){#fig-adv}

The crucial difference between adversarial examples and counterfactuals is one of intent. While adversarial examples are typically intended to go unnoticed, counterfactuals in the context of Explainable AI are generally sought to be "plausible", "realistic" or "feasible". Researchers have proposed a myriad of ways to achieve this. @joshi2019realistic were among the first to suggest that instead of searching for counterfactuals in the feature space, we can traverse a latent embedding learned by a surrogate generative model. Similarly, @poyiadzi2020face use density ... Finally, @karimi2021algorithmic argue that counterfactuals should comply with the causal model that generates them [CHECK IF WE CAN PHRASE THIS LIKE THIS]. Other related approaches include ... All of these different approaches have a common goal: they aim to ensure that the generated counterfactuals comply with the (learned) data-generating process (DGP).

::: {#def-plausible}

## Plausible Counterfactuals

Formally, if $x \sim \mathcal{X}$ and for the corresponding counterfactual we have $x^{\prime}\sim\mathcal{X}^{\prime}$, then for $x^{\prime}$ to be considered a plausible counterfactual, we need: $\mathcal{X} \approxeq \mathcal{X}^{\prime}$.

:::

In the context of Algorithmic Recourse, it makes sense to strive for plausible counterfactuals, since anything else would essentially require individuals to move to out-of-distribution states. But it is worth noting that our ambition to meet this goal may have implications for our ability to faithfully explain the behaviour of the underlying black-box model (arguably our principal goal). By decoupling the task of learning plausible representations of the data from the model itself, we open ourselves up to vulnerabilities. Using a separate generative model to learn $\mathcal{X}$, for example, makes the generated counterfactuals depend on the quality of that generative model. @fig-latent compares the results of applying REVISE [@joshi2019realistic] to MNIST data using two different Variational Auto-Encoders: while the counterfactual generated using an expressive (strong) VAE is compelling, the result relying on a less expressive (weak) VAE is not even valid. In this latter case, the decoder step of the VAE fails to yield values in $\mathcal{X}$ and hence the counterfactual search in the learned latent space is doomed.

```{julia}
using CounterfactualExplanations.Models: load_mnist_vae
vae = load_mnist_vae()
vae_weak = load_mnist_vae(;strong=false)
Serialization.serialize(joinpath(output_path,"mnist_classifier.jls"), M)
Serialization.serialize(joinpath(output_path,"mnist_vae.jls"), vae)
Serialization.serialize(joinpath(output_path,"mnist_vae_weak.jls"), vae_weak)
```

```{julia}
# Define generator:
generator = REVISEGenerator(
  opt = Descent(0.1),
  λ=0.01
)
# Generate recourse:
counterfactual_data.generative_model = vae # assign generative model
ce = generate_counterfactual(
    x, target, counterfactual_data, M, generator; 
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
counterfactual_data = deepcopy(counterfactual_data)
counterfactual_data.generative_model = vae_weak
ce_weak = generate_counterfactual(
    x, target, counterfactual_data, M, generator;
    decision_threshold=γ, max_iter=T,
    initialization=:identity,
)
```

```{julia}
ces = zip([ce,ce_weak])
counterfactuals = reduce((x,y)->cat(x,y,dims=3),map(ce -> CounterfactualExplanations.counterfactual(ce[1]), ces))
phat = reduce((x,y) -> cat(x,y,dims=3), map(ce -> target_probs(ce[1]), ces))
plts = [p1]
for x in zip(eachslice(counterfactuals; dims=3), eachslice(phat; dims=3))
    ce, _phat = (x[1],x[2])
    _title = "p(y=$(target)|x′)=$(round(_phat[1]; digits=3))"
    plt = plot(
        convert2image(MNIST, reshape(ce,28,28)),
        axis=nothing, 
        size=(img_height, img_height),
        title=_title
    )
    plts = [plts..., plt]
end
plt = plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
savefig(plt, joinpath(www_path, "surrogate_gone_wrong.png"))
```

![Counterfactual explanations for MNIST using a Latent Space generator with a strong and a weak VAE: turning an eight (8) into a three (3).](www/surrogate_gone_wrong.png){#fig-latent}

> Here it would be nice to have another example where we poison the data going into the generative model to hide biases present in the data (e.g. Boston housing).

- Latent space can be manipulated:
    - train biased model
    - train VAE with biased variable removed/attacked (use Boston housing dataset)
    - hypothesis: will generate bias-free explanations

### From Plausible to High-Fidelity Counterfactuals {#sec-fidelity}

In light of these findings, we propose to generally avoid using surrogate models to learn $\mathcal{X}$ in the context of Counterfactual Explanations.

::: {#prp-surrogate}

## Avoid Surrogates

Since we are in the business of explaining a black-box model, the task of learning realistic representations of the data should not be reallocated from the model itself to some surrogate model.

:::

In cases where the use of surrogate models cannot be avoided, we propose to weigh the plausibility of counterfactuals against their fidelity to the black-box model. In the context of Explainable AI, fidelity describes how well an explanation approximates the prediction of the black-box model [@molnar2020interpretable]. Fidelity has become the default metric for evaluating local model-agnostic explanation methods, since they often involve local surrogate models whose predictions need not always match those of the black-box model.

In the case of Counterfactual Explanations, the concept of fidelity has so far been ignored. This is not altogether surprising, since by construction and design, Counterfactual Explanations work with the predictions of the black-box model directly: as stated above, a counterfactual $x^{\prime}$ is considered valid if and only if $f(x^{\prime})=t$, where $t$ denotes some target outcome.

Does fidelity even make sense in the context of CE, and if so, how can we define it? In light of the examples in the previous section, we think it is urgent to introduce a notion of fidelity in this context that relates to the distributional properties of the generated counterfactuals. In particular, we propose that a high-fidelity counterfactual $x^{\prime}$ complies with the class-conditional distribution $\mathcal{X}_{\theta}|y = p_{\theta}(X|y)$ where $\theta$ denotes the black-box model parameters.

::: {#def-fidele}

## High-Fidelity Counterfactuals

Let $\mathcal{X}_{\theta}|y = p_{\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\theta$. Then for $x^{\prime}$ to be considered a high-fidelity counterfactual, we need: $\mathcal{X}_{\theta}|t \approxeq \mathcal{X}^{\prime}$ where $t$ denotes the target outcome.

:::

In order to assess the fidelity of counterfactuals, we propose the following two-step procedure:

1) Generate samples $X_{\theta}|t$ and $X^{\prime}$ from $\mathcal{X}_{\theta}|t$ and $\mathcal{X}^{\prime}$, respectively.
2) Compute the Maximum Mean Discrepancy (MMD) between $X_{\theta}|t$ and $X^{\prime}$. 

If the computed value is significantly different from zero, we can reject the null hypothesis of fidelity.
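A minimal sketch of the second step, assuming a Gaussian kernel with a fixed bandwidth. Since a sample estimate of the MMD is almost never exactly zero, the sketch pairs it with a permutation test to judge significance. Neither the kernel nor the test is fixed by this proposal, the helper names `mmd2` and `mmd_permutation_test` are hypothetical, and the two samples are synthetic stand-ins rather than draws from a real model.

```julia
using Random, Statistics

# Gaussian (RBF) kernel between two observations
rbf(a, b; bandwidth=1.0) = exp(-sum(abs2, a .- b) / (2 * bandwidth^2))

# Biased (V-statistic) estimate of the squared MMD; observations are stored column-wise
function mmd2(X, Y; bandwidth=1.0)
    k(A, B) = mean(rbf(a, b; bandwidth=bandwidth) for a in eachcol(A), b in eachcol(B))
    return k(X, X) + k(Y, Y) - 2 * k(X, Y)
end

# Permutation test: how often does a random relabelling yield an MMD at least as large?
function mmd_permutation_test(X, Y; bandwidth=1.0, n_perm=500, rng=Random.default_rng())
    observed = mmd2(X, Y; bandwidth=bandwidth)
    Z, n = hcat(X, Y), size(X, 2)
    exceed = 0
    for _ in 1:n_perm
        idx = randperm(rng, size(Z, 2))
        Xp, Yp = Z[:, idx[1:n]], Z[:, idx[n+1:end]]
        exceed += mmd2(Xp, Yp; bandwidth=bandwidth) ≥ observed
    end
    return observed, (exceed + 1) / (n_perm + 1)   # statistic and p-value
end

rng = MersenneTwister(2023)
X_model = randn(rng, 2, 40)        # stand-in for samples from the model's class-conditional distribution
X_cf = randn(rng, 2, 40) .+ 3.0    # stand-in for counterfactuals far from that distribution
stat, pval = mmd_permutation_test(X_model, X_cf; rng=rng)
```

A small p-value would lead us to reject the null hypothesis of fidelity. The bandwidth matters in practice and would need to be chosen with care, for example through a median heuristic.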

> Two challenges here: 1) implementing the sampling procedure in @grathwohl2020your; 2) it is unclear if MMD is really the right way to measure this. 

## Conformal Counterfactual Explanations

In @sec-fidelity, we have advocated for avoiding surrogate models in the context of Counterfactual Explanations. In this section, we introduce an alternative way to generate high-fidelity Counterfactual Explanations. In particular, we propose Conformal Counterfactual Explanations (CCE), that is, Counterfactual Explanations that minimize the predictive uncertainty of conformal models.

### Minimizing Predictive Uncertainty

@schut2021generating demonstrated that the goal of generating realistic (plausible) counterfactuals can also be achieved by seeking counterfactuals that minimize the predictive uncertainty of the underlying black-box model. Similarly, @antoran2020getting ...

- Problem: restricted to Bayesian models.
- Solution: post-hoc predictive uncertainty quantification. In particular, Conformal Prediction. 

### Background on Conformal Prediction

- Distribution-free, model-agnostic and scalable approach to predictive uncertainty quantification.
- Conformal prediction is instance-based. So is CE. 
- Take any fitted model and turn it into a conformal model using calibration data.
- Our approach, therefore, relaxes the restriction on the family of black-box models, at the cost of relying on a subset of the data. Arguably, data is often abundant and in most applications practitioners tend to hold out a test data set anyway. 
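
The calibration step mentioned above can be sketched for classification as follows. This is a minimal illustration of split conformal prediction, assuming a nonconformity score of one minus the predicted probability of the true label. The function names `calibrate` and `prediction_set` are hypothetical and do not refer to any existing package.

```julia
# Calibrate on held-out data: probs_cal is a K × n matrix of predicted class
# probabilities, y_cal holds the true labels in 1:K, α is the tolerated error rate
function calibrate(probs_cal, y_cal; α=0.1)
    n = length(y_cal)
    # Nonconformity score: one minus the probability assigned to the true label
    scores = sort([1 - probs_cal[y_cal[i], i] for i in 1:n])
    k = ceil(Int, (n + 1) * (1 - α))     # finite-sample corrected rank
    return k > n ? Inf : scores[k]       # too little data: threshold admits every label
end

# Prediction set: every label whose nonconformity score is within the threshold
prediction_set(probs, qhat) = findall(p -> 1 - p ≤ qhat, probs)
```

For a new input, `prediction_set` returns all labels that are compatible with the calibrated threshold. Larger sets indicate higher predictive uncertainty, which is the quantity we would like counterfactuals to minimize.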

> Does the coverage guarantee carry over to counterfactuals?

### Generating Conformal Counterfactuals

While Conformal Prediction has recently grown in popularity, it does introduce a challenge in the context of classification: the predictions of Conformal Classifiers are set-valued and therefore difficult to work with, since they are, for example, non-differentiable. Fortunately, @stutz2022learning introduced carefully designed differentiable loss functions that make it possible to evaluate the performance of conformal predictions in training. We can leverage these recent advances in the context of gradient-based counterfactual search ...
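
To illustrate the kind of relaxation involved, here is a minimal sketch of a smooth set-size penalty in the spirit of @stutz2022learning. It assumes nonconformity scores of one minus the predicted probability and replaces the hard set-membership indicator with a sigmoid at temperature `temp`. The function names are hypothetical and this is not a faithful reimplementation of their loss functions.

```julia
# Logistic sigmoid used to relax the hard set-membership indicator
sigmoid(z) = 1 / (1 + exp(-z))

# Soft membership per label: near 1 when the nonconformity score 1 - p is well below qhat
soft_membership(probs, qhat; temp=0.1) = sigmoid.((qhat .- (1 .- probs)) ./ temp)

# Smooth set-size penalty: soft set size in excess of the target size κ, clamped at zero
function smooth_set_size(probs, qhat; temp=0.1, κ=1)
    return max(0, sum(soft_membership(probs, qhat; temp=temp)) - κ)
end
```

Because the penalty is differentiable in the predicted probabilities, it could in principle be added to the counterfactual search objective to penalize counterfactuals with large prediction sets.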

> Challenge: still need to implement these loss functions. 

## Experiments

### Research Questions

- Is CP alone enough to ensure realistic counterfactuals?
- Do counterfactuals improve further as the models get better?
- Do counterfactuals get more realistic as coverage
- What happens as we vary coverage and set size?
- What happens as we improve the model robustness?
- What happens as we improve the model's ability to incorporate predictive uncertainty (deep ensemble, Laplace)?
- What happens if we combine with DiCE, ClaPROAR, Gravitational?
- What about CE robustness to endogenous shifts [@altmeyer2023endogenous]?

- Benchmarking:
    - add PROBE [@pawelczyk2022probabilistically] into the mix.
    - compare travel costs to domain shifts.

> Nice to have: What about using Laplace Approximation, then Conformal Prediction? What about using Conformalised Laplace? 

## References