Skip to content
Snippets Groups Projects
Commit 4f0c8894 authored by pat-alt's avatar pat-alt
Browse files

slowly slowly

parent f4692408
No related branches found
No related tags found
No related merge requests found
```{julia}
using CCE
using CCE: set_size_penalty, distance_from_energy, distance_from_targets
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
......@@ -11,34 +12,208 @@ using LinearAlgebra
using MLJBase
using MLJFlux
using Plots
img_height = 300
```
# Fidelity Measures
# `ConformalGenerator`
## Binary
In this section, we will look at a simple example involving synthetic data, a black-box model and a generic Conformal Counterfactual Generator.
## Black-box Model
We consider a simple binary classification problem. Let $(X_i, Y_i), \ i=1,...,n$ denote our feature-label pairs and let $\mu: \mathcal{X} \mapsto \mathcal{Y}$ denote the mapping from features to labels. For illustration purposes, we will use linearly separable data.
```{julia}
# Setup
counterfactual_data = load_linearly_separable()
M = fit_model(counterfactual_data, :DeepEnsemble)
target = 2
factual = 1
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
x = select_factual(counterfactual_data, chosen)
```
# Search:
generator = GenericGenerator(opt=Descent(0.01))
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
While we could use a linear classifier in this case, let's pretend we need a black-box model for this task and rely on a small Multi-Layer Perceptron (MLP):
```{julia}
builder = MLJFlux.@builder Chain(
Dense(n_in, 32, relu),
Dense(32, n_out)
)
clf = NeuralNetworkClassifier(builder=builder, epochs=100)
```
We can fit this model to data to produce plug-in predictions.
## Conformal Prediction
Here we will instead use a specific case of CP called *split conformal prediction* which can then be summarized as follows:^[In other places split conformal prediction is sometimes referred to as *inductive* conformal prediction.]
1. Partition the training into a proper training set and a separate calibration set: $\mathcal{D}_n=\mathcal{D}^{\text{train}} \cup \mathcal{D}^{\text{cali}}$.
2. Train the machine learning model on the proper training set: $\hat\mu_{i \in \mathcal{D}^{\text{train}}}(X_i,Y_i)$.
The model $\hat\mu_{i \in \mathcal{D}^{\text{train}}}$ can now produce plug-in predictions.
::: callout-note
## Starting Point
Note that this represents the starting point in applications of Algorithmic Recourse: we have some pre-trained classifier $M$ for which we would like to generate plausible Counterfactual Explanations. Next, we turn to the calibration step.
:::
3. Compute nonconformity scores, $\mathcal{S}$, using the calibration data $\mathcal{D}^{\text{cali}}$ and the fitted model $\hat\mu_{i \in \mathcal{D}^{\text{train}}}$.
4. For a user-specified desired coverage ratio $(1-\alpha)$ compute the corresponding quantile, $\hat{q}$, of the empirical distribution of nonconformity scores, $\mathcal{S}$.
5. For the given quantile and test sample $X_{\text{test}}$, form the corresponding conformal prediction set:
$$
C(X_{\text{test}})=\{y:s(X_{\text{test}},y) \le \hat{q}\}
$$ {#eq-set}
This is the default procedure used for classification and regression in [`ConformalPrediction.jl`](https://github.com/pat-alt/ConformalPrediction.jl).
Using the package, we can apply Split Conformal Prediction as follows:
```{julia}
X = table(permutedims(counterfactual_data.X))
y = counterfactual_data.output_encoder.labels
conf_model = conformal_model(clf; method=:simple_inductive)
mach = machine(conf_model, X, y)
fit!(mach)
```
To be clear, all of the calibration steps (3 to 5) are post hoc, and yet none of them involved any changes to the model parameters. These are two important characteristics of Split Conformal Prediction (SCP) that make it particularly useful in the context of Algorithmic Recourse. Firstly, the fact that SCP involves posthoc calibration steps that happen after training, ensures that we need not place any restrictions on the black-box model itself. This stands in contrast to the approach proposed by @schut2021generating in which they essentially restrict the class of models to Bayesian models. Secondly, the fact that the model itself is kept entirely intact ensures that the generated counterfactuals maintain fidelity to the model. Finally, note that we also have not resorted to a surrogate model to learn more about $X \sim \mathcal{X}$. Instead, we have used the fitted model itself and a calibration data set to learn about the model's predictive uncertainty.
## Differentiable CP
In order to use CP in the context of gradient-based counterfactual search, we need it to be differentiable. @stutz2022learning introduce a framework for training differentiable conformal predictors. They introduce a configurable loss function as well as smooth set size penalty.
### Smooth Set Size Penalty
Starting with the former, @stutz2022learning propose the following:
$$
\Omega(C_{\theta}(x;\tau)) = = \max (0, \sum_k C_{\theta,k}(x;\tau) - \kappa)
$$ {#eq-size-loss}
Here, $C_{\theta,k}(x;\tau)$ is loosely defined as the probability that class $k$ is assigned to the conformal prediction set $C$. In the context of Conformal Training, this penalty reduces the **inefficiency** of the conformal predictor.
In our context, we are not interested in improving the model itself, but rather in producing **plausible** counterfactuals. Provided that our counterfactual $x^\prime$ is already inside the target domain ($\mathbb{I}_{y^\prime = t}=1$), penalizing $\Omega(C_{\theta}(x;\tau))$ corresponds to guiding counterfactuals into regions of the target domain that are characterized by low ambiguity: for $\kappa=1$ the conformal prediction set includes only the target label $t$ as $\Omega(C_{\theta}(x;\tau))$. Arguably, less ambiguous counterfactuals are more **plausible**. Since the search is guided purely by properties of the model itself and (exchangeable) calibration data, counterfactuals also maintain **high fidelity**.
The left panel of @fig-losses shows the smooth size penalty in the two-dimensional feature space of our synthetic data.
### Configurable Classification Loss
The right panel of @fig-losses shows the configurable classification loss in the two-dimensional feature space of our synthetic data.
```{julia}
#| output: true
#| echo: false
#| label: fig-losses
#| fig-cap: "Illustration of the smooth size loss and the configurable classification loss."
temp = 0.5
p1 = contourf(mach.model, mach.fitresult, X, y; plot_set_loss=true, zoom=0, temp=temp)
p2 = contourf(mach.model, mach.fitresult, X, y; plot_classification_loss=true, target=target, zoom=0, temp=temp, clim=nothing, loss_matrix=ones(2,2))
plot(p1, p2, size=(800,320))
```
## Fidelity and Plausibility
The main evaluation criteria we are interested in are *fidelity* and *plausibility*. Interestingly, we could also consider using these measures as penalties in the counterfactual search.
### Fidelity
We propose to define fidelity as follows:
::: {#def-fidelity}
## High-Fidelity Counterfactuals
Let $\mathcal{X}_{\theta}|y = p_{\theta}(X|y)$ denote the class-conditional distribution of $X$ defined by $\theta$. Then for $x^{\prime}$ to be considered a high-fidelity counterfactual, we need: $\mathcal{X}_{\theta}|t \approxeq \mathcal{X}^{\prime}$ where $t$ denotes the target outcome.
:::
We can generate samples from $p_{\theta}(X|y)$ following @grathwohl2020your. In @fig-energy, I have applied the methodology to our synthetic data.
```{julia}
#| fig-cap: "Energy-based conditional samples."
#| label: fig-energy
niter = 100
nsamples = 100
sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
Xgen = rand(sampler, nsamples)
plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
plts = []
for target ∈ counterfactual_data.y_levels
sampler = CCE.EnergySampler(M, counterfactual_data, target; niter=niter, nsamples=100)
Xgen = rand(sampler, nsamples)
plt = plot(M, counterfactual_data; target=target, zoom=-3,cbar=false)
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
push!(plts, plt)
end
plot(plts..., layout=(1,length(plts)), size=(img_height*length(plts),img_height))
```
As an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\prime}$ from these generated samples, for example.
### Plausibility
We propose to define plausibility as follows:
::: {#def-plausible}
## Plausible Counterfactuals
Formally, let $\mathcal{X}|t$ denote the conditional distribution of samples in the target class. As before, we have $x^{\prime}\sim\mathcal{X}^{\prime}$, then for $x^{\prime}$ to be considered a plausible counterfactual, we need: $\mathcal{X}|t \approxeq \mathcal{X}^{\prime}$.
:::
As an evaluation metric and penalty, we could use the average distance of the counterfactual $x^{\prime}$ from (potentially bootstrapped) training samples in the target class, for example.
## Counterfactual Explanations
Next, let's generate counterfactual explanations for our synthetic data. We first wrap our model in a container that makes it compatible with `CounterfactualExplanations.jl`. Then we draw a random sample, determine its predicted label $\hat{y}$ and choose the opposite label as our target.
```{julia}
M = CCE.ConformalModel(conf_model, mach.fitresult)
x = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2)))
y_factual = predict_label(M, counterfactual_data, x)[1]
target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y_factual][1]
```
The generic Conformal Counterfactual Generator penalises the only the set size only:
$$
x^\prime = \arg \min_{x^\prime} \ell(M(x^\prime),t) + \lambda \mathbb{I}_{y^\prime = t} \Omega(C_{\theta}(x;\tau))
$$ {#eq-solution}
```{julia}
#| output: true
#| echo: false
#| label: fig-ce
#| fig-cap: "Comparison of counterfactuals produced using different generators."
opt = Descent(0.01)
ordered_names = [
"Generic (γ=0.5)",
"Conformal (λ₂=1)",
"Conformal (λ₂=10)"
]
loss_fun = Objectives.logitbinarycrossentropy
generator = GenericGenerator(opt = opt)
# Generators:
generators = Dict(
ordered_names[1] => generator,
ordered_names[2] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 1.0set_size_penalty),
ordered_names[3] => deepcopy(generator) |> gen -> @objective(gen, _ + 0.1distance_l2 + 10.0set_size_penalty),
)
counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen; initialization=:identity, converge_when=:generator_conditions, gradient_tol=1e-3) for (name, gen) in generators])
# Plots:
plts = []
for name ∈ ordered_names
ce = counterfactuals[name]
plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0)
plts = vcat(plts..., plt)
end
_n = length(generators)
img_size = 300
plot(plts..., size=(_n * img_size,1.05*img_size), layout=(1,_n))
```
## Multi-Class
......@@ -88,4 +263,5 @@ using CCE: distance_from_targets
)
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
p3 = plot(ce)
```
\ No newline at end of file
```
......@@ -56,16 +56,16 @@ In the binary case logits are fed through the sigmoid function instead of softma
which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
"""
function Models.logits(M::ConformalModel, X::AbstractArray)
yhat = SliceMap.slicemap(X, dims=(1, 2)) do x
conf_model = M.model
fitresult = M.fitresult
# x = MLJBase.table(permutedims(x))
# p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
# p̂ = map(p̂) do pp
# L = p̂.decoder.classes
# probas = pdf.(pp, L)
# return probas
# end
conf_model = M.model
fitresult = M.fitresult
# x = MLJBase.table(permutedims(x))
# = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...)
# p̂ = map(p̂) do pp
# L = p̂.decoder.classes
# probas = pdf.(pp, L)
# return probas
# end
function predict_logits(fitresult, x)
= fitresult[1](x)
if ndims() == 2
= []
......@@ -78,6 +78,13 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
= ndims() > 1 ? : permutedims([])
return
end
if ndims(X) > 2
yhat = map(eachslice(X, dims=ndims(X))) do x
predict_logits(fitresult, x)
end
else
yhat = predict_logits(fitresult, X)
end
return yhat
end
......
......@@ -3,42 +3,113 @@ using Distributions
using Flux
using JointEnergyModels
"""
(model::AbstractFittedModel)(x)
When called on data `x`, softmax logits are returned. In the binary case, outputs are one-hot encoded.
"""
(model::AbstractFittedModel)(x) = log.(CounterfactualExplanations.predict_proba(model, nothing, x))
"Base type that stores information relevant to energy-based posterior sampling from `AbstractFittedModel`."
mutable struct EnergySampler
ce::CounterfactualExplanation
model::AbstractFittedModel
data::CounterfactualData
sampler::JointEnergyModels.ConditionalSampler
opt::JointEnergyModels.AbstractSamplingRule
buffer::AbstractArray
buffer::Union{Nothing,AbstractArray}
yidx::Union{Nothing,Any}
end
"""
EnergySampler(
model::AbstractFittedModel,
data::CounterfactualData,
y::Any;
opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
niter::Int=100,
nsamples::Int=1000
)
Constructor for `EnergySampler` that takes a `model`, `data` and conditioning value `y` as inputs.
"""
function EnergySampler(
ce::CounterfactualExplanation;
model::AbstractFittedModel,
data::CounterfactualData,
y::Any;
opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
niter::Int=100,
nsamples::Int=1000
)
# Setup:
model = ce.M
data = ce.data
@assert y data.y_levels || y 1:length(data.y_levels)
K = length(data.y_levels)
𝒟x = Normal()
𝒟y = Categorical(ones(K) ./ K)
sampler = ConditionalSampler(𝒟x, 𝒟y)
yidx = get_target_index(data.y_levels, y)
# Initiate:
energy_sampler = EnergySampler(model, data, sampler, opt, nothing, nothing)
# Generate samples:
generate_samples!(energy_sampler, nsamples, yidx; niter=niter)
return energy_sampler
end
"""
generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
Generates `n` samples from `EnergySampler` for conditioning value `y`.
"""
function generate_samples(e::EnergySampler, n::Int, y::Int; niter::Int=100)
X = e.sampler(e.model, e.opt, (size(e.data.X, 1), n); niter=niter, y=y)
return X
end
# Fit:
i = get_target_index(data.y_levels, ce.target)
buffer = sampler(model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
"""
generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
return EnergySampler(ce, sampler, opt, buffer)
Generates `n` samples from `EnergySampler` for conditioning value `y`. Assigns samples and conditioning value to `EnergySampler`.
"""
function generate_samples!(e::EnergySampler, n::Int, y::Int; niter::Int=100)
e.buffer = generate_samples(e,n,y;niter=niter)
e.yidx = y
end
function Base.rand(sampler::EnergySampler, n::Int=100; retrain=false)
ntotal = size(sampler.buffer,2)
"""
EnergySampler(
ce::CounterfactualExplanation;
kwrgs...
)
Constructor for `EnergySampler` that takes a `CounterfactualExplanation` as input. The underlying model, data and `target` are used for the `EnergySampler`, where `target` is the conditioning value of `y`.
"""
function EnergySampler(
ce::CounterfactualExplanation;
kwrgs...
)
# Setup:
model = ce.M
data = ce.data
y = ce.target
return EnergySampler(model, data, y; kwrgs...)
end
"""
Base.rand(sampler::EnergySampler, n::Int=100; retrain=false)
Overloads the `rand` method to randomly draw `n` samples from `EnergySampler`.
"""
function Base.rand(sampler::EnergySampler, n::Int=100; from_buffer=true, niter::Int=100)
ntotal = size(sampler.buffer, 2)
idx = rand(1:ntotal, n)
if !retrain
X = sampler.buffer[:,idx]
if from_buffer
X = sampler.buffer[:, idx]
else
X = generate_samples(sampler, n, sampler.yidx; niter=niter)
end
return X
end
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment