Skip to content
Snippets Groups Projects
Commit f4692408 authored by pat-alt's avatar pat-alt
Browse files

differentiable penalties for fidelity and plausibility

parent b57c969f
No related branches found
No related tags found
No related merge requests found
...@@ -15,10 +15,38 @@ using Plots ...@@ -15,10 +15,38 @@ using Plots
# Fidelity Measures # Fidelity Measures
## Binary
```{julia}
# Setup
counterfactual_data = load_linearly_separable()
M = fit_model(counterfactual_data, :DeepEnsemble)
target = 2
factual = 1
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
x = select_factual(counterfactual_data, chosen)
# Search:
generator = GenericGenerator(opt=Descent(0.01))
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
```
```{julia}
niter = 100
nsamples = 100
sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
Xgen = rand(sampler, nsamples)
plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
```
## Multi-Class
```{julia} ```{julia}
# Setup # Setup
counterfactual_data = load_multi_class() counterfactual_data = load_multi_class()
M = fit_model(counterfactual_data, :MLP) M = fit_model(counterfactual_data, :DeepEnsemble)
target = 4 target = 4
factual = 2 factual = 2
chosen = rand(findall(predict_label(M, counterfactual_data) .== factual)) chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
...@@ -45,7 +73,19 @@ p1 = plot(ce) ...@@ -45,7 +73,19 @@ p1 = plot(ce)
```{julia} ```{julia}
@objective(generator, _ + 0.1distance_l2 + 100.0distance_from_energy) using CCE: distance_from_energy
@objective(generator, _ + 0.1distance_l2 + 10.0distance_from_energy)
ce = generate_counterfactual(x, target, counterfactual_data, M, generator) ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
p2 = plot(ce) p2 = plot(ce)
```
```{julia}
using CCE: distance_from_targets
@objective(
generator,
_ + 0.1distance_l2 + 1.0distance_from_energy + 10.0distance_from_targets
)
ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
p3 = plot(ce)
``` ```
\ No newline at end of file
...@@ -59,4 +59,25 @@ function distance_from_energy( ...@@ -59,4 +59,25 @@ function distance_from_energy(
return loss return loss
end end
\ No newline at end of file
function distance_from_targets(
counterfactual_explanation::AbstractCounterfactualExplanation;
n::Int=100, agg=mean
)
target_samples = counterfactual_explanation.data.X |>
X -> X[:,rand(1:end,n)]
x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
loss = map(eachslice(x′, dims=3)) do x
x = Matrix(x)
Δ = map(eachcol(target_samples)) do xsample
norm(x - xsample)
end
return mean(Δ)
end
loss = agg(loss)
return loss
end
using CounterfactualExplanations using CounterfactualExplanations
using Distributions using Distributions
using Flux
using JointEnergyModels using JointEnergyModels
(model::AbstractFittedModel)(x) = logits(model, x) (model::AbstractFittedModel)(x) = log.(CounterfactualExplanations.predict_proba(model, nothing, x))
mutable struct EnergySampler mutable struct EnergySampler
ce::CounterfactualExplanation ce::CounterfactualExplanation
...@@ -28,7 +29,7 @@ function EnergySampler( ...@@ -28,7 +29,7 @@ function EnergySampler(
# Fit: # Fit:
i = get_target_index(data.y_levels, ce.target) i = get_target_index(data.y_levels, ce.target)
buffer = sampler(model.model, opt, (size(data.X, 1), nsamples); niter=niter, y=i) buffer = sampler(model, opt, (size(data.X, 1), nsamples); niter=niter, y=i)
return EnergySampler(ce, sampler, opt, buffer) return EnergySampler(ce, sampler, opt, buffer)
end end
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment