Commit 72e17418 authored by Pat Alt

need to regularize energy delta

parent 828bcdf2
1 merge request: !76 "69 initial run including fmnist lenet and new method"
...
@@ -31,7 +31,7 @@ Builds a dictionary of default models for training.
 function default_models(;
     sampler::AbstractSampler,
     builder::MLJFlux.Builder=default_builder(),
-    epochs::Int=25,
+    epochs::Int=100,
     batch_size::Int=128,
     finaliser::Function=Flux.softmax,
     loss::Function=Flux.Losses.crossentropy,
...
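For orientation, a hedged usage sketch of `default_models` under the new default. The keyword names come from the hunk above, and the sampler construction mirrors the notebook cell removed further down; treating `JointEnergyModels` as the provider of `ConditionalSampler` is an assumption.

```julia
using Distributions: Categorical, Normal
using JointEnergyModels: ConditionalSampler  # assumed provider of the sampler type

# Two-class, two-feature toy setup mirroring the notebook code below:
𝒟x = Normal()
𝒟y = Categorical(ones(2) ./ 2)
sampler = ConditionalSampler(𝒟x, 𝒟y; input_size=(2,), batch_size=50)

models = default_models(; sampler=sampler)                   # epochs now defaults to 100
models_short = default_models(; sampler=sampler, epochs=25)  # the old default, now opt-in
```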
 ```{julia}
-include("$(pwd())/notebooks/setup.jl")
-eval(setup_notebooks)
+include("$(pwd())/experiments/setup_env.jl")
 ```
 # Linearly Separable Data
 ```{julia}
-# Hyper:
-_retrain = false
-# Data:
-test_size = 0.2
-n_obs = Int(1000 / (1.0 - test_size))
-counterfactual_data, test_data = train_test_split(
-    load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.));
-    test_size=test_size
-)
-X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
-X = table(permutedims(X))
-labels = counterfactual_data.output_encoder.labels
-input_dim, n_obs = size(counterfactual_data.X)
-output_dim = length(unique(labels))
-```
-First, let's create a couple of image classifier architectures:
-```{julia}
-# Model parameters:
-epochs = 100
-bs = minimum([Int(round(n_obs/10)), 128])
-n_hidden = 16
-activation = Flux.swish
-builder = MLJFlux.MLP(
-    hidden=(n_hidden, n_hidden, n_hidden),
-    σ=Flux.swish
-)
-n_ens = 5 # number of models in ensemble
-_loss = Flux.Losses.crossentropy # loss function
-_finaliser = Flux.softmax # finaliser function
-```
-```{julia}
-# JEM parameters:
-𝒟x = Normal()
-𝒟y = Categorical(ones(output_dim) ./ output_dim)
-sampler = ConditionalSampler(
-    𝒟x, 𝒟y,
-    input_size=(input_dim,),
-    batch_size=50,
-)
-α = [1.0,1.0,1e-1] # penalty strengths
-```
-```{julia}
-# Joint Energy Model:
-model = JointEnergyClassifier(
-    sampler;
-    builder=builder,
-    epochs=epochs,
-    batch_size=bs,
-    finaliser=_finaliser,
-    loss=_loss,
-    jem_training_params=(
-        α=α, verbosity=10,
-    ),
-    sampling_steps=30,
-)
-```
-```{julia}
-conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
-mach = machine(conf_model, X, labels)
-@info "Begin training model."
-fit!(mach)
-@info "Finished training model."
-M = ECCCo.ConformalModel(mach.model, mach.fitresult)
-```
-```{julia}
-λ₁ = 0.25
-λ₂ = 0.75
-λ₃ = 0.75
-Λ = [λ₁, λ₂, λ₃]
-opt = Flux.Optimise.Descent(0.01)
-use_class_loss = false
-# Benchmark generators:
-generator_dict = Dict(
-    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
-    "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true),
-)
+dataname = "linearly_separable"
+outcome = Serialization.deserialize(joinpath(DEFAULT_OUTPUT_PATH, "$(dataname)_outcome.jls"))
+# Unpack
+exp = outcome.exp
+model_dict = outcome.model_dict
+generator_dict = outcome.generator_dict
+bmk = outcome.bmk
 ```
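The rewritten cell above assumes an experiment script has already serialized all artifacts into one bundle on disk. A hypothetical sketch of the producer side; the `NamedTuple` container is an assumption inferred from the fields the notebook unpacks:

```julia
using Serialization

# Producer-side sketch (hypothetical): bundle the experiment artifacts so the
# notebook can unpack them by field name, then write the bundle to disk.
outcome = (; exp, model_dict, generator_dict, bmk)  # NamedTuple stand-in for the real outcome type
Serialization.serialize(joinpath(DEFAULT_OUTPUT_PATH, "$(dataname)_outcome.jls"), outcome)
```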
 ```{julia}
 Random.seed!(2023)
+# Unpack
+counterfactual_data = exp.counterfactual_data
+X, labels = counterfactual_data.X, counterfactual_data.output_encoder.labels
+M = model_dict["MLP"]
+gen = filter(((k,v),) -> k in ["ECCCo", "ECCCo-Δ"], generator_dict)
+# Prepare search:
 X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
 factual_label = levels(labels)[2]
-x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
+x_factual = X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))] |>
+    x -> x[:,:]
 target = levels(labels)[1]
 factual = predict_label(M, counterfactual_data, x_factual)[1]
 ces = Dict{Any,Any}()
 plts = []
-for (name, generator) in generator_dict
+for (name, generator) in gen
     ce = generate_counterfactual(
         x_factual, target, counterfactual_data, M, generator;
         initialization=:identity,
...
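A note on the `x_factual` change in the hunk above: indexing a vector with a trailing colon (`x[:, :]`) materializes an `n×1` matrix, so the call no longer depends on `input_dim` being in scope. A minimal self-contained check:

```julia
v = rand(Float32, 3)          # stand-in for the selected factual column X[:, i]
A = v[:, :]                   # trailing-colon indexing yields a 3×1 Matrix
@assert A == reshape(v, 3, 1) # same result as the old reshape(..., input_dim, 1)
@assert size(A) == (3, 1)
```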
...
@@ -76,14 +76,19 @@ function energy_delta(
     xproposed = CounterfactualExplanations.decode_state(ce) # current state
     t = get_target_index(ce.data.y_levels, ce.target)
     E(x) = -logits(ce.M, x)[t,:] # negative logits for target class
-    _loss = E(xproposed) .- E(xgenerated)
-    _loss = reduce((x, y) -> x + y, _loss) / n # aggregate over samples
+    # Generative loss:
+    gen_loss = E(xproposed) .- E(xgenerated)
+    gen_loss = reduce((x, y) -> x + y, gen_loss) / n # aggregate over samples
+    # Regularization loss:
+    reg_loss = E(xgenerated).^2 .+ E(xproposed).^2
+    reg_loss = reduce((x, y) -> x + y, reg_loss) / n # aggregate over samples
     if return_conditionals
         return conditional_samples[1]
     end
-    return _loss
+    return gen_loss + 0.1reg_loss
 end
...
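This last hunk is what the commit message refers to: on its own, the energy delta `E(xproposed) .- E(xgenerated)` can be driven arbitrarily negative by inflating energy magnitudes, so the change adds a squared-energy penalty (in the spirit of the ℓ2 energy regularization used when training joint energy models) with a hard-coded weight of 0.1. A standalone toy sketch of the new objective; the energy values are made up for illustration:

```julia
# Toy energies standing in for E(xproposed) and E(xgenerated) in the diff:
E_prop = [-2.0, -1.5]   # energies of the proposed counterfactual state
E_gen  = [-0.5, -0.25]  # energies of the model-generated samples
n = length(E_gen)

gen_loss = sum(E_prop .- E_gen) / n        # energy delta, averaged over samples
reg_loss = sum(E_gen.^2 .+ E_prop.^2) / n  # squared-energy penalty keeps |E| bounded
total = gen_loss + 0.1reg_loss             # mirrors `gen_loss + 0.1reg_loss` in the diff
```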