Unverified commit 6c11c5c8 authored by Pat Alt, committed by GitHub

Merge pull request #45 from pat-alt/44-use-energy-instead-of-distance

44 use energy instead of distance
parents 23d21265 a0664452
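This merge makes the energy penalty in the ECCCo generator switchable: instead of the distance-from-energy term, an energy-delta term can be used via a new `use_energy_delta` keyword. A minimal usage sketch (the λ values are taken from the notebooks in this diff; all other arguments keep their defaults):

```julia
using ECCCo

# Default behaviour, as before: distance-from-energy penalty.
gen = ECCCoGenerator(λ=[0.1, 0.25, 0.25])

# New in this merge: energy-delta penalty.
gen_delta = ECCCoGenerator(λ=[0.1, 0.1, 3.0], use_energy_delta=true)
```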
(Binary files and the notebooks/.CondaPkg micromamba environment logs were regenerated; no preview available for these files.)
@@ -164,8 +164,8 @@ end
```{julia}
# Hyper:
-_retrain = true
+_retrain = false
-_regen = true
+_regen = false
# Data:
n_obs = 10000
@@ -357,6 +357,17 @@ model_performance
### Different Models
```{julia}
function plot_mnist(ce; size=(img_height, img_height), kwrgs...)
    x = CounterfactualExplanations.counterfactual(ce)
    phat = target_probs(ce)
    plt = Plots.plot(
        convert2image(MNIST, reshape(x,28,28));
        axis=([], false),
        size=size,
        kwrgs...,
    )
end
function _plot_eccco_mnist(
    x::Union{AbstractArray, Int}=x_factual, target::Int=target;
    λ=[0.1,0.25,0.25],
@@ -372,6 +383,7 @@ function _plot_eccco_mnist(
    plot_factual::Bool = false,
    generator::Union{Nothing,CounterfactualExplanations.AbstractGenerator}=nothing,
    test_data::Bool = false,
    use_energy_delta::Bool = false,
    kwrgs...,
)
@@ -392,6 +404,7 @@ function _plot_eccco_mnist(
        use_class_loss=use_class_loss,
        nsamples=10,
        nmin=10,
        use_energy_delta=use_energy_delta,
    )
end
@@ -458,6 +471,27 @@ display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco.png"))
```
#### Energy Delta (not in paper)
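The energy-delta variant replaces the distance-based energy penalty with the difference in target-class energy between the proposed counterfactual and conditional samples drawn from the model. A sketch of the penalty, read off the `energy_delta` code in this diff (not a formula from the paper):

$$\frac{1}{n}\sum_{i=1}^{n}\big(E_t(x^\prime) - E_t(\hat{x}_i)\big), \qquad E_t(x) = -\mathrm{logit}_t(x),$$

where $x^\prime$ is the current counterfactual state and $\hat{x}_i$ are conditional samples from the model's energy sampler.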
```{julia}
plt, gen_delta, ces = _plot_eccco_mnist(λ = [0.1,0.1,3.0], use_energy_delta=true)
display(plt)
savefig(plt, joinpath(output_images_path, "mnist_eccco_energy_delta.png"))
```
```{julia}
λ_delta = [0.1,0.1,2.5]
λ = [0.1,0.25,0.25]
plts = []
for i in 0:9
    plt, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ, plot_title="Distance")
    plt_delta, _, _ = _plot_eccco_mnist(x_factual, i; λ = λ_delta, use_energy_delta=true, plot_title="Energy Delta")
    plt = Plots.plot(plt, plt_delta; size=(img_height*2,img_height), layout=(1,2))
    display(plt)
    push!(plts, plt)
end
```
#### Additional Models (not in paper)
LeNet-5:
@@ -774,6 +808,29 @@ if _regen_all_digits
end
```
#### Energy Delta (not in paper)
```{julia}
_regen_all_digits = true
if _regen_all_digits
    function plot_all_digits(rng=123;verbose=true,img_height=180,kwargs...)
        plts = []
        for i in 0:9
            for j in 0:9
                @info "Generating counterfactual for $(i) -> $(j)"
                plt = plot_mnist(i,j;kwargs...,rng=rng, img_height=img_height)
                !verbose || display(plt)
                plts = [plts..., plt]
            end
        end
        plt = Plots.plot(plts...; size=(img_height*10,img_height*10), layout=(10,10), dpi=300)
        return plt
    end
    plt = plot_all_digits(generator=gen_delta)
    savefig(plt, joinpath(output_images_path, "mnist_eccco_all_digits-delta.png"))
end
```
## Benchmark
```{julia}
@@ -854,4 +911,4 @@ plt = draw(
)
display(plt)
save(joinpath(output_images_path, "mnist_benchmark.png"), plt, px_per_unit=5)
```
\ No newline at end of file
```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```
# Linearly Separable Data
```{julia}
# Hyper:
_retrain = false
# Data:
test_size = 0.2
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(
load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.));
test_size=test_size
)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```
First, let's set up the classifier architecture and training parameters:
```{julia}
# Model parameters:
epochs = 100
batch_size = minimum([Int(round(n_obs/10)), 128])
n_hidden = 16
activation = Flux.swish
builder = MLJFlux.MLP(
hidden=(n_hidden, n_hidden, n_hidden),
σ=Flux.swish
)
n_ens = 5 # number of models in ensemble
_loss = Flux.Losses.crossentropy # loss function
_finaliser = Flux.softmax # finaliser function
```
```{julia}
# JEM parameters:
𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
𝒟x, 𝒟y,
input_size=(input_dim,),
batch_size=50,
)
α = [1.0,1.0,1e-1] # penalty strengths
```
```{julia}
# Joint Energy Model:
model = JointEnergyClassifier(
sampler;
builder=builder,
epochs=epochs,
batch_size=batch_size,
finaliser=_finaliser,
loss=_loss,
jem_training_params=(
α=α,verbosity=10,
),
sampling_steps=30,
)
```
```{julia}
conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
mach = machine(conf_model, X, labels)
@info "Begin training model."
fit!(mach)
@info "Finished training model."
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
```
```{julia}
λ₁ = 0.25
λ₂ = 0.75
λ₃ = 0.75
Λ = [λ₁, λ₂, λ₃]
opt = Flux.Optimise.Descent(0.01)
use_class_loss = false
# Benchmark generators:
generator_dict = Dict(
"ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
"ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true),
)
```
```{julia}
Random.seed!(2023)
X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
factual_label = levels(labels)[1]
x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
target = levels(labels)[2]
factual = predict_label(M, counterfactual_data, x_factual)[1]
ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2,
        cbar=false,
    )
    if contains(name, "ECCCo")
        _X = distance_from_energy(ce, return_conditionals=true)
        Plots.scatter!(
            _X[1,:], _X[2,:], color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.5
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(800,350))
display(plt)
```
\ No newline at end of file
@@ -25,6 +25,7 @@ function ECCCoGenerator(;
    use_class_loss::Bool=false,
    nsamples::Int=50,
    nmin::Int=25,
    use_energy_delta::Bool=false,
    kwargs...
)
@@ -47,7 +48,11 @@ function ECCCoGenerator(;
    # Energy penalty
    function _energy_penalty(ce::AbstractCounterfactualExplanation)
-        return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        if use_energy_delta
+            return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...)
+        else
+            return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        end
    end
    _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
......
using ChainRules: ignore_derivatives
using CounterfactualExplanations: get_target_index
using Distances
using Flux
using LinearAlgebra: norm
@@ -38,6 +39,54 @@ function set_size_penalty(
end
function energy_delta(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    kwargs...
)
    _loss = 0.0
    nmin = minimum([nmin, n])
    @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
    conditional_samples = []
    ignore_derivatives() do
        _dict = ce.params
        if !(:energy_sampler ∈ collect(keys(_dict)))
            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
        end
        eng_sampler = _dict[:energy_sampler]
        if choose_lowest_energy
            nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
            xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
            push!(conditional_samples, xmin)
        elseif choose_random
            push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
        else
            push!(conditional_samples, eng_sampler.buffer)
        end
    end
    xgenerated = conditional_samples[1]                      # conditional samples
    xproposed = CounterfactualExplanations.decode_state(ce)  # current state
    t = get_target_index(ce.data.y_levels, ce.target)
    E(x) = -logits(ce.M, x)[t,:]                             # negative logits for target class
    _loss = E(xproposed) .- E(xgenerated)
    _loss = reduce((x, y) -> x + y, _loss) / n               # aggregate over samples
    if return_conditionals
        return conditional_samples[1]
    end
    return _loss
end
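# A reading of `energy_delta` above (a sketch, not a documented formula): the penalty is
# (1/n) * Σᵢ [Eₜ(x′) − Eₜ(x̂ᵢ)], where Eₜ(x) = -logits(M, x)[t] is the target-class energy,
# x′ is the decoded counterfactual state, and x̂ᵢ are the conditional samples. Because the
# sampler and the samples are created inside `ignore_derivatives`, gradients flow only
# through Eₜ(x′); the generated samples act as a fixed baseline during optimization.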
function distance_from_energy(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
......