Skip to content
Snippets Groups Projects
Commit f4692408 authored by pat-alt's avatar pat-alt
Browse files

differentiable penalties for fidelity and plausibility

parent b57c969f
No related branches found
No related tags found
No related merge requests found
......@@ -15,10 +15,38 @@ using Plots
# Fidelity Measures
## Binary
```{julia}
# Setup
counterfactual_data = load_linearly_separable()
M = fit_model(counterfactual_data, :DeepEnsemble)
target = 2
factual = 1
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
x = select_factual(counterfactual_data, chosen)
# Search:
generator = GenericGenerator(opt=Descent(0.01))
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
```
```{julia}
niter = 100
nsamples = 100
sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
Xgen = rand(sampler, nsamples)
plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
```
## Multi-Class
```{julia}
# Setup
counterfactual_data = load_multi_class()
M = fit_model(counterfactual_data, :MLP)
M = fit_model(counterfactual_data, :DeepEnsemble)
target = 4
factual = 2
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
......@@ -45,7 +73,19 @@ p1 = plot(ce)
```{julia}
@objective(generator, _ + 0.1distance_l2 + 100.0distance_from_energy)
using CCE: distance_from_energy
@objective(generator, _ + 0.1distance_l2 + 10.0distance_from_energy)
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
p2 = plot(ce)
```
```{julia}
using CCE: distance_from_targets
@objective(
generator,
_ + 0.1distance_l2 + 1.0distance_from_energy + 10.0distance_from_targets
)
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
p3 = plot(ce)
```
\ No newline at end of file
......@@ -59,4 +59,25 @@ function distance_from_energy(
return loss
end
\ No newline at end of file
end
function distance_from_targets(
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, agg=mean
)
target_samples = counterfactual_explanation.data.X |>
X -> X[:,rand(1:end,n)]
x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
loss = map(eachslice(x′, dims=3)) do x
x = Matrix(x)
Δ = map(eachcol(target_samples)) do xsample
norm(x - xsample)
end
return mean(Δ)
end
loss = agg(loss)
return loss
end
using CounterfactualExplanations
using Distributions
using Flux
using JointEnergyModels
(model::AbstractFittedModel)(x) = logits(model, x)
(model::AbstractFittedModel)(x) = log.(CounterfactualExplanations.predict_proba(model, nothing, x))
mutable struct EnergySampler
ce::CounterfactualExplanation
......@@ -28,7 +29,7 @@ function EnergySampler(
# Fit:
i = get_target_index(data.y_levels, ce.target)
buffer = sampler(model.model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
buffer = sampler(model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
return EnergySampler(ce, sampler, opt, buffer)
end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment