Skip to content
Snippets Groups Projects
Commit 76088414 authored by Pat Alt's avatar Pat Alt
Browse files

why did my brain decide to come up with this last night and not 3 months ago :fire: :cry:

parent 335f7133
No related branches found
No related tags found
1 merge request!4544 use energy instead of distance
No preview for this file type
No preview for this file type
No preview for this file type
......@@ -164,8 +164,8 @@ end
```{julia}
# Hyper:
_retrain = true
_regen = true
_retrain = false
_regen = false
# Data:
n_obs = 10000
......@@ -357,6 +357,17 @@ model_performance
### Different Models
```{julia}
function plot_mnist(ce; size=(img_height, img_height), kwrgs...)
x = CounterfactualExplanations.counterfactual(ce)
phat = target_probs(ce)
plt = Plots.plot(
convert2image(MNIST, reshape(x,28,28));
axis=([], false),
size=size,
kwrgs...,
)
end
function _plot_eccco_mnist(
x::Union{AbstractArray, Int}=x_factual, target::Int=target;
λ=[0.1,0.25,0.25],
......@@ -372,6 +383,7 @@ function _plot_eccco_mnist(
plot_factual::Bool = false,
generator::Union{Nothing,CounterfactualExplanations.AbstractGenerator}=nothing,
test_data::Bool = false,
use_energy_delta::Bool = false,
kwrgs...,
)
......@@ -392,6 +404,7 @@ function _plot_eccco_mnist(
use_class_loss=use_class_loss,
nsamples=10,
nmin=10,
use_energy_delta=use_energy_delta,
)
end
......@@ -458,6 +471,21 @@ display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```
#### Energy Delta (not in paper)
```{julia}
λ_delta = [0.1,0.1,2.5]
λ = [0.1,0.25,0.25]
plts = []
for i in 0:9
plt, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ, plot_title="Distance")
plt_delta, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ_delta, use_energy_delta=true, plot_title="Energy Delta")
plt = Plots.plot(plt, plt_delta; size=(img_height*2,img_height), layout=(1,2))
display(plt)
push!(plts, plt)
end
```
#### Additional Models (not in paper)
LeNet-5:
......
using ChainRules: ignore_derivatives
using CounterfactualExplanations: get_target_index
using Distances
using Flux
using LinearAlgebra: norm
......@@ -73,7 +74,8 @@ function energy_delta(
xtarget = conditional_samples[1] # conditional samples
x = CounterfactualExplanations.decode_state(ce) # current state
E(x) = -logits(ce.M, x)[ce.target,:] # negative logits for target class
t = get_target_index(ce.data.y_levels, ce.target)
E(x) = -logits(ce.M, x)[t,:] # negative logits for target class
_loss = E(x) .- E(xtarget)
_loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment