Commit 335f7133 authored by Pat Alt

:fire:

parent 23d21265
1 merge request: !45 "44 use energy instead of distance"
==> 2023-08-08 15:03:23 <==
==> 2023-08-10 11:22:12 <==
# cmd: /Users/FA31DU/.julia/artifacts/6ecf04294c7f327e02e84972f34835649a5eb35e/bin/micromamba -r /Users/FA31DU/.julia/scratchspaces/0b3b1443-0f03-428d-bdfb-f27f9c1191ea/root create -y -p /Users/FA31DU/code/ECCCo.jl/notebooks/.CondaPkg/env --override-channels --no-channel-priority numpy[version='*'] pip[version='>=22.0.0'] python[version='>=3.7,<4',channel='conda-forge',build='*cpython*'] -c conda-forge
# conda version: 3.8.0
+https://conda.anaconda.org/conda-forge/osx-64::xz-5.2.6-h775f41a_0
......
```{julia}
include("$(pwd())/notebooks/setup.jl")
eval(setup_notebooks)
```
# Linearly Separable Data
```{julia}
# Hyper:
_retrain = false
# Data:
test_size = 0.2
n_obs = Int(1000 / (1.0 - test_size))
counterfactual_data, test_data = train_test_split(
    load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.));
    test_size=test_size
)
X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
X = table(permutedims(X))
labels = counterfactual_data.output_encoder.labels
input_dim, n_obs = size(counterfactual_data.X)
output_dim = length(unique(labels))
```
First, let's set up the classifier architecture and training parameters:
```{julia}
# Model parameters:
epochs = 100
batch_size = minimum([Int(round(n_obs/10)), 128])
n_hidden = 16
activation = Flux.swish
builder = MLJFlux.MLP(
    hidden=(n_hidden, n_hidden, n_hidden),
    σ=activation
)
n_ens = 5 # number of models in ensemble
_loss = Flux.Losses.crossentropy # loss function
_finaliser = Flux.softmax # finaliser function
```
```{julia}
# JEM parameters:
𝒟x = Normal()
𝒟y = Categorical(ones(output_dim) ./ output_dim)
sampler = ConditionalSampler(
    𝒟x, 𝒟y,
    input_size=(input_dim,),
    batch_size=50,
)
α = [1.0, 1.0, 1e-1] # penalty strengths
```
```{julia}
# Joint Energy Model:
model = JointEnergyClassifier(
    sampler;
    builder=builder,
    epochs=epochs,
    batch_size=batch_size,
    finaliser=_finaliser,
    loss=_loss,
    jem_training_params=(
        α=α, verbosity=10,
    ),
    sampling_steps=30,
)
```
```{julia}
conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
mach = machine(conf_model, X, labels)
@info "Begin training model."
fit!(mach)
@info "Finished training model."
M = ECCCo.ConformalModel(mach.model, mach.fitresult)
```
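For context, `method=:simple_inductive` corresponds to split (inductive) conformal prediction: nonconformity scores are computed on a held-out calibration split and their quantile defines the prediction sets. The sketch below illustrates the idea only; the function and variable names are made up and this is not the ConformalPrediction.jl implementation.
```{julia}
# Illustrative only: split ("simple inductive") conformal classification in a few lines.
using Statistics: quantile

function conformal_prediction_sets(p_cal, y_cal, p_new; coverage=0.95)
    # Nonconformity score on the calibration split: 1 − p̂(true class)
    scores = [1 - p_cal[y_cal[i], i] for i in eachindex(y_cal)]
    # Finite-sample corrected quantile of the calibration scores
    n_cal = length(scores)
    qhat = quantile(scores, clamp(ceil((n_cal + 1) * coverage) / n_cal, 0.0, 1.0))
    # Prediction set for each new observation: all classes whose score stays below qhat
    return [findall(p_new[:, j] .>= 1 - qhat) for j in axes(p_new, 2)]
end
```
The `_set_size_penalty` term listed in the generator diff further below operates on prediction sets of this kind during the counterfactual search.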
```{julia}
λ₁ = 0.25
λ₂ = 0.75
λ₃ = 0.75
Λ = [λ₁, λ₂, λ₃]
opt = Flux.Optimise.Descent(0.01)
use_class_loss = false
# Benchmark generators:
generator_dict = Dict(
    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
    "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true),
)
```
```{julia}
Random.seed!(2023)
X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
factual_label = levels(labels)[1]
x_factual = reshape(
    X[:, rand(findall(predict_label(M, counterfactual_data) .== factual_label))],
    input_dim, 1
)
target = levels(labels)[2]
factual = predict_label(M, counterfactual_data, x_factual)[1]
ces = Dict{Any,Any}()
plts = []
for (name, generator) in generator_dict
    ce = generate_counterfactual(
        x_factual, target, counterfactual_data, M, generator;
        initialization=:identity,
        converge_when=:generator_conditions,
    )
    plt = Plots.plot(
        ce, title=name, alpha=0.2,
        cbar=false,
    )
    if contains(name, "ECCCo")
        _X = distance_from_energy(ce; return_conditionals=true)
        Plots.scatter!(
            _X[1, :], _X[2, :], color=:purple, shape=:star5,
            ms=10, label="x̂|$target", alpha=0.5
        )
    end
    push!(plts, plt)
    ces[name] = ce
end
plt = Plots.plot(plts..., size=(800,350))
display(plt)
```
\ No newline at end of file
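As an optional follow-up, one could log the raw penalty values for each counterfactual to see how the two variants differ. This is only a sketch, assuming `distance_from_energy` and `ECCCo.energy_delta` (added in the diff below) are in scope as above:
```{julia}
# Sketch: compare the distance-based and energy-delta penalties per counterfactual.
for (name, ce) in ces
    @info name distance_from_energy(ce) ECCCo.energy_delta(ce)
end
```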
@@ -25,6 +25,7 @@ function ECCCoGenerator(;
    use_class_loss::Bool=false,
    nsamples::Int=50,
    nmin::Int=25,
+   use_energy_delta::Bool=false,
    kwargs...
)
@@ -47,7 +48,11 @@ function ECCCoGenerator(;
    # Energy penalty
    function _energy_penalty(ce::AbstractCounterfactualExplanation)
-       return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+       if use_energy_delta
+           return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...)
+       else
+           return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+       end
    end
    _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
......
@@ -38,6 +38,53 @@ function set_size_penalty(
end

function energy_delta(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
    choose_lowest_energy=true,
    choose_random=false,
    nmin::Int=25,
    return_conditionals=false,
    kwargs...
)
    _loss = 0.0
    nmin = minimum([nmin, n])

    @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."

    conditional_samples = []
    ignore_derivatives() do
        _dict = ce.params
        if !(:energy_sampler ∈ collect(keys(_dict)))
            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
        end
        eng_sampler = _dict[:energy_sampler]
        if choose_lowest_energy
            nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
            xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
            push!(conditional_samples, xmin)
        elseif choose_random
            push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
        else
            push!(conditional_samples, eng_sampler.buffer)
        end
    end

    xtarget = conditional_samples[1]                 # conditional samples
    x = CounterfactualExplanations.decode_state(ce)  # current state
    E(x) = -logits(ce.M, x)[ce.target, :]            # negative logits for target class
    _loss = E(x) .- E(xtarget)
    _loss = reduce((x, y) -> x + y, _loss) / n       # aggregate over samples

    if return_conditionals
        return conditional_samples[1]
    end
    return _loss
end

function distance_from_energy(
    ce::AbstractCounterfactualExplanation;
    n::Int=50, niter=500, from_buffer=true, agg=mean,
......
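In essence, the new `energy_delta` penalty scores the counterfactual by the gap between its own energy, i.e. its negative target-class logit, and the mean energy of the conditional samples drawn for the target class, rather than by a distance to those samples. A toy illustration with made-up numbers:

```julia
# Toy numbers only; not part of the package code.
E(logit_target) = -logit_target         # energy = negative target-class logit

logit_cf    = 1.2                       # target-class logit at the counterfactual x′
logits_cond = [2.0, 2.4, 1.8]           # target-class logits at conditional samples x̂ᵢ

penalty = E(logit_cf) - sum(E.(logits_cond)) / length(logits_cond)
# ≈ 0.87: positive as long as x′ is still "less target-like" than the conditional samples
```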