Commit df4d31fb authored by pat-alt

new illustrative chart

parent 1790b8b3
artifacts/results/images/poc_gradient_fields.png

146 KiB

......@@ -2,7 +2,7 @@
julia_version = "1.8.5"
manifest_format = "2.0"
project_hash = "5c4085099192d26cfd97f9867471a3d11a289671"
project_hash = "a2a8b27b3c8c411d9bf595480742a6cbff281867"
[[deps.AbstractFFTs]]
deps = ["ChainRulesCore", "LinearAlgebra"]
......
......@@ -5,6 +5,7 @@ CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
CategoricalDistributions = "af321ab8-2d2e-40a6-b165-3d674595d28e"
Chain = "8be319e6-bccf-4806-a6f7-6fae938471bc"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
......
......@@ -154,6 +154,98 @@ display(plt)
savefig(plt, joinpath(output_images_path, "poc.png"))
```
```{julia}
using Colors
col_pal = palette(:seaborn_colorblind)
Random.seed!(1234)
using CounterfactualExplanations.Generators: ∇
# Penalty strengths (λ) and step size (η):
λ₁ = 0.1
λ₂ = 1.0
λ₃ = 0.5
Λ = [λ₁, λ₂, λ₃]
η = 0.01
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
# Pick a random factual from the second class and target the first class:
factual_label = levels(labels)[2]
x_factual = reshape(
    X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))],
    input_dim, 1,
)
target = levels(labels)[1]
factual = predict_label(M, counterfactual_data, x_factual)[1]
opt = Flux.Optimise.Descent(η)
generator_dict = OrderedDict(
    "Wachter" => WachterGenerator(λ = 0.3, opt=opt),
    "ECCCo (no EBM)" => ECCCoGenerator(λ = [λ₁,λ₂,0.0], opt=opt),
    "ECCCo (no CP)" => ECCCoGenerator(λ = [λ₁,0.0,λ₃], opt=opt),
    "ECCCo" => ECCCoGenerator(λ = Λ, opt=opt),
)
# Gradient field: evaluate the generator's loss gradient at an arbitrary point x.
function loss_grads(generator, model, ce, x)
    x = Float32.(x)
    _ce = deepcopy(ce)
    _ce.s′ = x
    # Use the `model` argument rather than the global `M`:
    return ∇(generator, model, _ce)
end
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))
xlims, ylims = extrema(X, dims=2)
xrange = range(xlims..., length=10)
yrange = range(ylims..., length=10)
x1, x2 = meshgrid(xrange, yrange)
inputs = zip(x1, x2)
# Plot:
ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    # CE:
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    # Main plot (path):
    plt = Plots.plot(
        ce, title=name, alpha=0.1, cbar=false,
        axis=nothing, length_out=10, contour_alpha=0.0,
        legend = false,
        palette = col_pal,
    )
    # Generated samples:
    if name ∈ ["ECCCo","ECCCo (no CP)"]
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1,:],_X[2,:], color=col_pal[end-1], shape=:star5,
            ms=10, label="x̂|$target", alpha=0.1
        )
    end
    # Gradient field (arrows point along the descent direction, scaled by η):
    u = []
    v = []
    for (x, y) in inputs
        g = -loss_grads(generator, M, ce, [x, y][:,:])
        push!(u, η * g[1])
        push!(v, η * g[2])
    end
    Plots.quiver!(x1, x2, quiver=(u, v), color=col_pal[5])
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(500,520))
display(plt)
savefig(plt, joinpath(output_images_path, "poc_gradient_fields.png"))
```
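The gradient-field construction in the chunk above can be checked in isolation with a small, package-free sketch. It reuses the `meshgrid` helper from the diff, but swaps the generator's loss gradient for a hypothetical quadratic loss (`toy_loss_grad` and `target_point` are illustrative stand-ins, not part of ECCCo or CounterfactualExplanations.jl), so the arrow components can be verified by hand.

```julia
# Same helper as in the chunk above: flatten a 2-D grid into two coordinate vectors.
meshgrid(x, y) = (repeat(x, outer=length(y)), repeat(y, inner=length(x)))

# Hypothetical stand-in for the counterfactual loss gradient:
# gradient of the squared distance to a fixed target point.
target_point = [1.0, -1.0]
toy_loss_grad(x) = 2 .* (x .- target_point)

η = 0.01                                  # step size, as in the chunk above
xrange = range(-2.0, 2.0, length=3)
yrange = range(-2.0, 2.0, length=3)
x1, x2 = meshgrid(xrange, yrange)         # x varies fastest, y slowest

# Arrow components: negative gradient (descent direction) scaled by η.
u = Float64[]
v = Float64[]
for (x, y) in zip(x1, x2)
    g = -toy_loss_grad([x, y])
    push!(u, η * g[1])
    push!(v, η * g[2])
end
```

The resulting `x1`, `x2`, `u` and `v` have one entry per grid point and could be passed to `Plots.quiver!` in the same way as above. Typed vectors (`Float64[]`) are used here instead of `[]` only to keep the element type concrete; that is a choice of this sketch, not of the original code.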
```{julia}
ce = ces["ECCCo"]
......@@ -206,7 +298,7 @@ inputs = zip(x1, x2)
u = []
v = []
scale = 0.1
for (x, y) in inputs
push!(u, scale * gradient(f1, [x, y][:,:])[1][1])
push!(v, scale * gradient(f1, [x, y][:,:])[1][2])
......
......@@ -312,7 +312,7 @@ We first briefly describe our experimental setup, before presenting our main res
To assess and benchmark the performance of ECCCo against the state of the art, we generate multiple counterfactuals for different black-box models and datasets. In particular, we compare ECCCo to the following counterfactual generators that were introduced above: firstly, Schut~\citep{schut2021generating}, which works under the premise of minimizing predictive uncertainty; secondly, REVISE~\citep{joshi2019realistic}, which is state-of-the-art with respect to plausibility; and, finally, Wachter~\citep{wachter2017counterfactual}, which serves as our baseline. We also consider two variations of ECCCo: `ECCCo (no CP)' involves no set size penalty ($\lambda_3=0$ in Equation~\ref{eq:eccco}), while `ECCCo (no EBM)' does not penalise the distance to samples generated through SGLD ($\lambda_2=0$ in Equation~\ref{eq:eccco}). These have been added to gain some sense of the degree to which the two components underlying ECCCo---namely energy-based modelling (EBM) and conformal prediction (CP)---drive the results.
We use both synthetic and real-world datasets from different domains, all of which are publicly available and commonly used to train and benchmark classification algorithms. We synthetically generate a dataset containing two \textbf{Linearly Separable} Gaussian clusters ($n=1000$), as well as the well-known \textbf{Circles} ($n=1000$) and \textbf{Moons} ($n=2500$) data. Since these data are generated by distributions of varying degrees of complexity, they allow us to gauge how complexity affects the different generators.
We use both synthetic and real-world datasets from different domains, all of which are publicly available and commonly used to train and benchmark classification algorithms. We synthetically generate a dataset containing two \textbf{Linearly Separable} Gaussian clusters ($n=1000$), as well as the well-known \textbf{Circles} ($n=1000$) and \textbf{Moons} ($n=2500$) data. Since these data are generated by distributions of varying degrees of complexity, they allow us to assess how the generators and our proposed evaluation metrics handle this.
As for real-world data, we follow~\citet{schut2021generating} and use the \textbf{MNIST}~\citep{lecun1998mnist} dataset containing images of handwritten digits such as the examples shown above. From the social sciences domain, we include Give Me Some Credit (\textbf{GMSC})~\citep{kaggle2011give}: a tabular dataset that has been studied extensively in the literature on Algorithmic Recourse~\citep{pawelczyk2021carla}. It consists of 11 numeric features that can be used to predict the binary outcome variable indicating whether or not retail borrowers experience financial distress.
......