Skip to content
Snippets Groups Projects
Commit a0664452 authored by Pat Alt's avatar Pat Alt
Browse files

uh

parent 76088414
No related branches found
No related tags found
1 merge request!4544 use energy instead of distance
......@@ -473,6 +473,12 @@ savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
#### Energy Delta (not in paper)
```{julia}
plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1,0.1,3.0], use_energy_delta=true)
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
```
```{julia}
λ_delta = [0.1,0.1,2.5]
λ = [0.1,0.25,0.25]
......@@ -802,6 +808,29 @@ if _regen_all_digits
end
```
#### Energy Delta (not in paper)
```{julia}
_regen_all_digits = true
if _regen_all_digits
function plot_all_digits(rng=123;verbose=true,img_height=180,kwargs...)
plts = []
for i in 0:9
for j in 0:9
@info "Generating counterfactual for $(i) -> $(j)"
plt = plot_mnist(i,j;kwargs...,rng=rng, img_height=img_height)
!verbose || display(plt)
plts = [plts..., plt]
end
end
plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10), dpi=300)
return plt
end
plt = plot_all_digits(generator=gen_delta)
savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
end
```
## Benchmark
```{julia}
......@@ -882,4 +911,4 @@ plt = draw(
)
display(plt)
save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
```
\ No newline at end of file
```
......@@ -72,13 +72,13 @@ function energy_delta(
end
end
xtarget = conditional_samples[1] # conditional samples
x = CounterfactualExplanations.decode_state(ce) # current state
xgenerated = conditional_samples[1] # conditional samples
xproposed = CounterfactualExplanations.decode_state(ce) # current state
t = get_target_index(ce.data.y_levels, ce.target)
E(x) = -logits(ce.M, x)[t,:] # negative logits for target class
_loss = E(x) .- E(xtarget)
E(x) = -logits(ce.M, x)[t,:] # negative logits for target class
_loss = E(xproposed) .- E(xgenerated)
_loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples
_loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples
if return_conditionals
return conditional_samples[1]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment