From 335f7133f0f0515686120db93f0eb3bcba1c39f9 Mon Sep 17 00:00:00 2001
From: Pat Alt <55311242+pat-alt@users.noreply.github.com>
Date: Thu, 10 Aug 2023 11:52:24 +0200
Subject: [PATCH] :fire:

---
 notebooks/.CondaPkg/env/conda-meta/history |   2 +-
 notebooks/.CondaPkg/meta                   | Bin 753 -> 753 bytes
 notebooks/prototyping.qmd                  | 130 +++++++++++++++++++++
 src/generator.jl                           |   7 +-
 src/penalties.jl                           |  47 ++++++++
 5 files changed, 184 insertions(+), 2 deletions(-)
 create mode 100644 notebooks/prototyping.qmd

diff --git a/notebooks/.CondaPkg/env/conda-meta/history b/notebooks/.CondaPkg/env/conda-meta/history
index dd869616..6c7683b9 100644
--- a/notebooks/.CondaPkg/env/conda-meta/history
+++ b/notebooks/.CondaPkg/env/conda-meta/history
@@ -1,4 +1,4 @@
-==> 2023-08-08 15:03:23 <==
+==> 2023-08-10 11:22:12 <==
 # cmd: /Users/FA31DU/.julia/artifacts/6ecf04294c7f327e02e84972f34835649a5eb35e/bin/micromamba -r /Users/FA31DU/.julia/scratchspaces/0b3b1443-0f03-428d-bdfb-f27f9c1191ea/root create -y -p /Users/FA31DU/code/ECCCo.jl/notebooks/.CondaPkg/env --override-channels --no-channel-priority numpy[version='*'] pip[version='>=22.0.0'] python[version='>=3.7,<4',channel='conda-forge',build='*cpython*'] -c conda-forge
 # conda version: 3.8.0
 +https://conda.anaconda.org/conda-forge/osx-64::xz-5.2.6-h775f41a_0
diff --git a/notebooks/.CondaPkg/meta b/notebooks/.CondaPkg/meta
index 1b11b47e6fbbf80a0ee54e97be26195177f40138..6b7ba5c59ae065f773611edc2fc41659836fb772 100644
GIT binary patch
delta 23
acmey!`jM57hXDd!sz)non{MQL%me^IFa?(Y

delta 23
acmey!`jM57hXDdI8RtHHX0nm*F%tkrhX$<x

diff --git a/notebooks/prototyping.qmd b/notebooks/prototyping.qmd
new file mode 100644
index 00000000..7b61b6f8
--- /dev/null
+++ b/notebooks/prototyping.qmd
@@ -0,0 +1,130 @@
+```{julia}
+include("$(pwd())/notebooks/setup.jl")
+eval(setup_notebooks)
+```
+
+# Linearly Separable Data
+
+```{julia}
+# Hyper:
+_retrain = false
+
+# Data:
+test_size = 0.2
+n_obs = Int(1000 / (1.0 - test_size))
+counterfactual_data, test_data = train_test_split(
+    load_blobs(n_obs; cluster_std=0.1, center_box=(-1. => 1.)); 
+    test_size=test_size
+)
+X, y = CounterfactualExplanations.DataPreprocessing.unpack_data(counterfactual_data)
+X = table(permutedims(X))
+labels = counterfactual_data.output_encoder.labels
+input_dim, n_obs = size(counterfactual_data.X)
+output_dim = length(unique(labels))
+```
+
+First, let's create a couple of image classifier architectures:
+
+```{julia}
+# Model parameters:
+epochs = 100
+batch_size = minimum([Int(round(n_obs/10)), 128])
+n_hidden = 16
+activation = Flux.swish
+builder = MLJFlux.MLP(
+    hidden=(n_hidden, n_hidden, n_hidden), 
+    σ=Flux.swish
+)
+n_ens = 5                                   # number of models in ensemble
+_loss = Flux.Losses.crossentropy       # loss function
+_finaliser = Flux.softmax                   # finaliser function
+```
+
+```{julia}
+# JEM parameters:
+𝒟x = Normal()
+𝒟y = Categorical(ones(output_dim) ./ output_dim)
+sampler = ConditionalSampler(
+    𝒟x, 𝒟y, 
+    input_size=(input_dim,), 
+    batch_size=50,
+)
+α = [1.0,1.0,1e-1]      # penalty strengths
+```
+
+
+```{julia}
+# Joint Energy Model:
+model = JointEnergyClassifier(
+    sampler;
+    builder=builder,
+    epochs=epochs,
+    batch_size=batch_size,
+    finaliser=_finaliser,
+    loss=_loss,
+    jem_training_params=(
+        α=α,verbosity=10,
+    ),
+    sampling_steps=30,
+)
+```
+
+```{julia}
+conf_model = conformal_model(model; method=:simple_inductive, coverage=0.95)
+mach = machine(conf_model, X, labels)
+@info "Begin training model."
+fit!(mach)
+@info "Finished training model."
+M = ECCCo.ConformalModel(mach.model, mach.fitresult)
+```
+
+```{julia}
+λ₁ = 0.25
+λ₂ = 0.75
+λ₃ = 0.75
+Λ = [λ₁, λ₂, λ₃]
+
+opt = Flux.Optimise.Descent(0.01)
+use_class_loss = false
+
+# Benchmark generators:
+generator_dict = Dict(
+    "ECCCo" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss),
+    "ECCCo (energy delta)" => ECCCoGenerator(λ=Λ, opt=opt, use_class_loss=use_class_loss, use_energy_delta=true),
+)
+```
+
+```{julia}
+Random.seed!(2023)
+
+X = X isa Matrix ? X : Float32.(permutedims(matrix(X)))
+factual_label =  levels(labels)[1]
+x_factual = reshape(X[:,rand(findall(predict_label(M, counterfactual_data).==factual_label))],input_dim,1)
+target =  levels(labels)[2]
+factual = predict_label(M, counterfactual_data, x_factual)[1]
+
+ces = Dict{Any,Any}()
+plts = []
+for (name, generator) in generator_dict
+    ce = generate_counterfactual(
+        x_factual, target, counterfactual_data, M, generator;
+        initialization=:identity, 
+        converge_when=:generator_conditions,
+    )
+    plt = Plots.plot(
+        ce, title=name, alpha=0.2, 
+        cbar=false, 
+    )
+    if contains(name, "ECCCo")
+        _X = distance_from_energy(ce, return_conditionals=true)
+        Plots.scatter!(
+            _X[1,:],_X[2,:], color=:purple, shape=:star5, 
+            ms=10, label="x̂|$target", alpha=0.5
+        )
+    end
+    push!(plts, plt)
+    ces[name] = ce
+end
+plt = Plots.plot(plts..., size=(800,350))
+display(plt)
+```
\ No newline at end of file
diff --git a/src/generator.jl b/src/generator.jl
index c356fb1a..2119b544 100644
--- a/src/generator.jl
+++ b/src/generator.jl
@@ -25,6 +25,7 @@ function ECCCoGenerator(;
     use_class_loss::Bool=false,
     nsamples::Int=50,
     nmin::Int=25,
+    use_energy_delta::Bool=false,
     kwargs...
 )
 
@@ -47,7 +48,11 @@ function ECCCoGenerator(;
 
     # Energy penalty
     function _energy_penalty(ce::AbstractCounterfactualExplanation)
-        return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        if use_energy_delta
+            return ECCCo.energy_delta(ce; n=nsamples, nmin=nmin, kwargs...)
+        else
+            return ECCCo.distance_from_energy(ce; n=nsamples, nmin=nmin, kwargs...)
+        end
     end
 
     _penalties = [Objectives.distance_l1, _set_size_penalty, _energy_penalty]
diff --git a/src/penalties.jl b/src/penalties.jl
index d8d1b7f6..b42a7071 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -38,6 +38,53 @@ function set_size_penalty(
 
 end
 
+function energy_delta(
+    ce::AbstractCounterfactualExplanation;
+    n::Int=50, niter=500, from_buffer=true, agg=mean,
+    choose_lowest_energy=true,
+    choose_random=false,
+    nmin::Int=25,
+    return_conditionals=false,
+    kwargs...
+)
+
+    _loss = 0.0
+    nmin = minimum([nmin, n])
+
+    @assert choose_lowest_energy ⊻ choose_random || !choose_lowest_energy && !choose_random "Must choose either lowest energy or random samples or neither."
+
+    conditional_samples = []
+    ignore_derivatives() do
+        _dict = ce.params
+        if !(:energy_sampler ∈ collect(keys(_dict)))
+            _dict[:energy_sampler] = ECCCo.EnergySampler(ce; niter=niter, nsamples=n, kwargs...)
+        end
+        eng_sampler = _dict[:energy_sampler]
+        if choose_lowest_energy
+            nmin = minimum([nmin, size(eng_sampler.buffer)[end]])
+            xmin = ECCCo.get_lowest_energy_sample(eng_sampler; n=nmin)
+            push!(conditional_samples, xmin)
+        elseif choose_random
+            push!(conditional_samples, rand(eng_sampler, n; from_buffer=from_buffer))
+        else
+            push!(conditional_samples, eng_sampler.buffer)
+        end
+    end
+
+    xtarget = conditional_samples[1]                    # conditional samples
+    x = CounterfactualExplanations.decode_state(ce)     # current state
+    E(x) = -logits(ce.M, x)[ce.target,:]                # negative logits for target class
+    _loss = E(x) .- E(xtarget)
+
+    _loss = reduce((x, y) -> x + y, _loss) / n          # aggregate over samples
+
+    if return_conditionals
+        return conditional_samples[1]
+    end
+    return _loss
+
+end
+
 function distance_from_energy(
     ce::AbstractCounterfactualExplanation;
     n::Int=50, niter=500, from_buffer=true, agg=mean, 
-- 
GitLab