Skip to content
Snippets Groups Projects
Commit cfd5467b authored by pat-alt's avatar pat-alt
Browse files

addressed cynthia's comments

parents f5b04ee9 b1e4fe8e
No related branches found
No related tags found
No related merge requests found
artifacts/results/images/mnist_benchmark.png

191 KiB

This diff is collapsed.
......@@ -148,9 +148,10 @@ _retrain = false
_regen = false
# Data:
n_obs = 1000
n_obs = 10000
counterfactual_data = load_mnist(n_obs)
counterfactual_data.X = pre_process.(counterfactual_data.X)
counterfactual_data.generative_model = vae
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
x_factual = reshape(pre_process(x_factual, noise=0.0f0), input_dim, 1)
......@@ -186,7 +187,7 @@ sampler = ConditionalSampler(
input_size=(input_dim,),
batch_size=10,
)
α = [1.0,1.0,1e-1] # penalty strengths
α = [1.0,1.0,1e-2] # penalty strengths
```
```{julia}
......@@ -258,6 +259,7 @@ end
```{julia}
# Plot generated samples:
n_regen = 150
if _regen
for (mod_name, mod) in model_dict
if ECCCo._has_sampler(mod)
......@@ -272,12 +274,11 @@ if _regen
opt = ImproperSGLD()
f(x) = logits(mod, x)
n_iter = 200
_w = 1500
plts = []
neach = 10
for i in 1:10
x = sampler(f, opt; niter=n_iter, n_samples=neach, y=i)
x = sampler(f, opt; niter=n_regen, n_samples=neach, y=i)
plts_i = []
for j in 1:size(x, 2)
xj = x[:,j]
......@@ -321,7 +322,7 @@ model_performance
```{julia}
function _plot_eccco_mnist(
x::Union{AbstractArray, Int}=x_factual, target::Int=target;
λ=[0.1,0.1,0.1],
λ=[0.5,0.1,0.5],
temp=0.1,η=0.01,
plt_order = ["MLP", "MLP Ensemble", "JEM", "JEM Ensemble"],
opt = Flux.Optimise.Adam(η),
......@@ -498,7 +499,7 @@ plts = plts[plt_order]
plts = [p1, plts...]
plt = Plots.plot(plts...; size=(img_height*length(plts),img_height), layout=(1,length(plts)))
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco_benchmark.png"))
savefig(plt, joinpath(output_images_path, "mnist_all_generators.png"))
```
## Benchmark
......@@ -509,7 +510,50 @@ measures = [
CounterfactualExplanations.distance,
ECCCo.distance_from_energy,
ECCCo.distance_from_targets,
CounterfactualExplanations.validity,
CounterfactualExplanations.redudancy,
CounterfactualExplanations.Evaluation.validity,
CounterfactualExplanations.Evaluation.redundancy,
]
# Run the benchmark on `counterfactual_data` for every model in `model_dict` and
# every generator in `generator_dict`, evaluating the entries of `measures`.
# NOTE(review): `factual=0, target=1` presumably selects the MNIST digit labels to
# explain from/to, and `suppress_training=true` presumably reuses the already
# fitted models — confirm against the `benchmark` documentation.
bmk = benchmark(
counterfactual_data;
models=model_dict,
generators=generator_dict,
measure=measures,
suppress_training=true, dataname="MNIST",
n_individuals=5,
factual=0, target=1,
initialization=:identity,
)
# Persist the benchmark results (`bmk()`) as CSV under `output_path`.
CSV.write(joinpath(output_path, "mnist_benchmark.csv"), bmk())
```
```{julia}
# Summarise the benchmark results: mean and standard deviation of `value` for each
# (dataname, generator, model, variable) group, then keep only the rows for the
# `distance_from_energy` measure.
@chain bmk() begin
@group_by(dataname, generator, model, variable)
@summarize(mean=mean(value),sd=std(value))
@ungroup
@filter(variable == "distance_from_energy")
end
```
```{julia}
# Keep the three distance-based measures and relabel them for plotting:
# "distance_from_energy" -> "Non-Conformity", "distance_from_targets" ->
# "Implausibility", "distance" -> "Cost".
df = @chain bmk() begin
@filter(variable in [
"distance_from_energy",
"distance_from_targets",
"distance",])
@mutate(variable = ifelse.(variable .== "distance_from_energy", "Non-Conformity", variable))
@mutate(variable = ifelse.(variable .== "distance_from_targets", "Implausibility", variable))
@mutate(variable = ifelse.(variable .== "distance", "Cost", variable))
end
# Box plots of `value` per generator, with one facet row per measure and one facet
# column per model; colour also encodes the generator.
plt = AlgebraOfGraphics.data(df) * visual(BoxPlot) *
mapping(:generator, :value, row=:variable, col=:model, color=:generator)
# x-axis label and ticks are hidden; `linkyaxes=:minimal` is passed as a facet option.
plt = draw(
plt, axis=(xlabel="", xticksvisible=false, xticklabelsvisible=false, width=150, height=120),
facet=(; linkyaxes=:minimal)
)
# Show the figure and write it to `mnist_benchmark.png` under `output_images_path`.
display(plt)
save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
```
\ No newline at end of file
No preview for this file type
......@@ -246,7 +246,7 @@ The first two terms in Equation~\ref{eq:eccco} correspond to the counterfactual
\begin{minipage}[c]{0.40\textwidth}
\centering
\includegraphics[width=\textwidth]{../artifacts/results/images/eccco_illustration.png}
\captionof{figure}{Vector fields indicating the direction of gradients with respect to the different components of the ECCCo objective (Equation~\ref{eq:eccco}).} \label{fig:eccco}
\captionof{figure}{[PLACEHOLDER] Vector fields indicating the direction of gradients with respect to the different components of the ECCCo objective (Equation~\ref{eq:eccco}).} \label{fig:eccco}
\end{minipage}
\hfill
\begin{minipage}[c]{0.50\textwidth}
......@@ -271,7 +271,7 @@ The first two terms in Equation~\ref{eq:eccco} correspond to the counterfactual
\begin{minipage}[c]{\textwidth}
\includegraphics[width=\textwidth]{../artifacts/results/images/mnist_eccco.png}
\captionof{figure}{Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:eccco-mnist}
\captionof{figure}{[SUBJECT TO CHANGE] Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:eccco-mnist}
\end{minipage}
\medskip
......@@ -282,7 +282,7 @@ Finally, we search counterfactuals through gradient descent. Let $\mathcal{L}(\m
Figure~\ref{fig:eccco-mnist} presents ECCCos for the MNIST example from Section~\ref{background} for various Black Box models of increasing complexity from left to right: a simple Multi-Layer Perceptron (MLP); an Ensemble of MLPs, each of the same architecture as the single MLP; a Joint Energy Model (JEM) based on the same MLP architecture; and finally, an Ensemble of these JEMs. Since Deep Ensembles have an improved capacity for predictive uncertainty quantification and JEMs are explicitly trained to learn plausible representations of the input data, it is intuitive to see that the plausibility of counterfactuals visibly improves from left to right. This provides some first anecdotal evidence that ECCCos achieve plausibility while maintaining faithfulness to the Black Box.
\section{Experiments}\label{conformity}
\section{Empirical Analysis}\label{emp}
In this section, we bolster our anecdotal findings from the previous section through rigorous empirical analysis. We first briefly describe our evaluation framework and data, before presenting and discussing our results.
......@@ -308,9 +308,22 @@ This measure is straightforward to compute and should be less sensitive to outli
As noted by \citet{guidotti2022counterfactual}, these distance-based measures are simplistic and more complex alternative measures may ultimately be more appropriate for the task. For example, we considered using statistical divergence measures instead. This would involve generating not one but many counterfactuals and comparing the generated empirical distribution to the target distributions in Definitions~\ref{def:plausible} and~\ref{def:conformal}. While this approach is potentially more rigorous, generating enough counterfactuals is not always practical.
\section{Experiments}
\subsection{Data}
\subsection{Results}
\begin{figure}
\includegraphics[width=\textwidth]{../artifacts/results/images/mnist_benchmark.png}
\caption{[SUBJECT TO CHANGE] Original image (left) and ECCCos for turning an 8 (eight) into a 3 (three) for different Black Boxes from left to right: Multi-Layer Perceptron (MLP), Ensemble of MLPs, Joint Energy Model (JEM), Ensemble of JEMs.}\label{fig:mnist-benchmark}
\end{figure}
\section{Discussion}
\subsection{Key Insights}
Consistent with the findings in \citet{schut2021generating}, we have demonstrated that predictive uncertainty estimates can be leveraged to generate plausible counterfactuals. Interestingly, \citet{schut2021generating} point out that this finding --- as intuitive as it is --- may be linked to a positive connection between the generative task and predictive uncertainty quantification. In particular, \citet{grathwohl2020your} demonstrate that their proposed method for integrating the generative objective in training yields models that have improved predictive uncertainty quantification. Since neither \citet{schut2021generating} nor we have employed any surrogate generative models, our findings seem to indicate that the positive connection found in \citet{grathwohl2020your} is bidirectional.
\subsection{Limitations}
\begin{itemize}
\item BatchNorm does not seem compatible with JEM
......@@ -323,10 +336,7 @@ As noted by \citet{guidotti2022counterfactual}, these distance-based measures ar
\item For MNIST it seems that ECCCo is better at reducing pixel values than increasing them (better at erasing than writing)
\end{itemize}
\section{Discussion}
Consistent with the findings in \citet{schut2021generating}, we have demonstrated that predictive uncertainty estimates can be leveraged to generate plausible counterfactuals. Interestingly, \citet{schut2021generating} point out that this finding --- as intuitive as it is --- may be linked to a positive connection between the generative task and predictive uncertainty quantification. In particular, \citet{grathwohl2020your} demonstrate that their proposed method for integrating the generative objective in training yields models that have improved predictive uncertainty quantification. Since neither \citet{schut2021generating} nor we have employed any surrogate generative models, our findings seem to indicate that the positive connection found in \citet{grathwohl2020your} is bidirectional.
\section{Conclusion}
\medskip
......
......@@ -47,7 +47,7 @@ function distance_from_energy(
_dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
end
sampler = _dict[:energy_sampler]
push!(conditional_samples, rand(sampler, 100; from_buffer=from_buffer))
push!(conditional_samples, rand(sampler, n; from_buffer=from_buffer))
end
x′ = CounterfactualExplanations.counterfactual(ce)
loss = map(eachslice(x′, dims=ndims(x′))) do x
......@@ -70,10 +70,9 @@ function distance_from_targets(
target_samples = ce.data.X[:,target_idx] |>
X -> X[:,rand(1:end,n)]
x′ = CounterfactualExplanations.counterfactual(ce)
loss = map(eachslice(x′, dims=3)) do x
x = Matrix(x)
loss = map(eachslice(x′, dims=ndims(x′))) do x
Δ = map(eachcol(target_samples)) do xsample
norm(x - xsample)
norm(x - xsample, 1)
end
return mean(Δ)
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment