From b57c969f92724510725a7f345ac24c126613d726 Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Fri, 17 Mar 2023 17:10:53 +0100
Subject: [PATCH] penalty working

---
 notebooks/fidelity.qmd | 37 ++++++++++++++++++++++---------------
 src/penalties.jl       | 29 ++++++++++++++++++++++++-----
 src/sampling.jl        |  2 +-
 3 files changed, 47 insertions(+), 21 deletions(-)

diff --git a/notebooks/fidelity.qmd b/notebooks/fidelity.qmd
index 2b7e384a..def8c275 100644
--- a/notebooks/fidelity.qmd
+++ b/notebooks/fidelity.qmd
@@ -20,25 +20,32 @@ using Plots
 counterfactual_data = load_multi_class()
 M = fit_model(counterfactual_data, :MLP)
 target = 4
-factual = 1
+factual = 2
 chosen = rand(findall(predict_label(M, counterfactual_data) .== factual))
 x = select_factual(counterfactual_data, chosen)
-```
 
+# Search:
+generator = GenericGenerator(opt=Descent(0.01))
+ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
+```
 
 ```{julia}
-niter = 10000
+niter = 100
 nsamples = 100
-plts = []
-for target in ce.data.y_levels
-    # Search:
-    generator = GenericGenerator()
-    ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
-    sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
-    Xgen = rand(sampler, nsamples)
-    plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
-    scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
-    push!(plts, plt)
-end
-plot(plts..., layout=(1,length(ce.data.y_levels)), size=(length(ce.data.y_levels)*300,300))
+
+sampler = CCE.EnergySampler(ce;niter=niter, nsamples=100)
+Xgen = rand(sampler, nsamples)
+plt = plot(M, counterfactual_data, target=ce.target, xlims=(-5,5),ylims=(-5,5),cbar=false)
+scatter!(Xgen[1,:],Xgen[2,:],alpha=0.5,color=target,shape=:star,label="X|y=$target")
+```
+
+```{julia}
+p1 = plot(ce)
+```
+
+
+```{julia}
+@objective(generator, _ + 0.1distance_l2 + 100.0distance_from_energy)
+ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
+p2 = plot(ce)
 ```
\ No newline at end of file
diff --git a/src/penalties.jl b/src/penalties.jl
index 611fb3e7..a3f35a9b 100644
--- a/src/penalties.jl
+++ b/src/penalties.jl
@@ -1,3 +1,5 @@
+using ChainRules: ignore_derivatives
+using LinearAlgebra: norm
 using Statistics: mean
 
 """
@@ -33,11 +35,28 @@ function set_size_penalty(
 end
 
 function distance_from_energy(
-    counterfactual_explanation::AbstractCounterfactualExplanation; 
-    n::Int=100, retrain=false, kwargs...
+    counterfactual_explanation::AbstractCounterfactualExplanation;
+    n::Int=100, retrain=false, agg=mean, kwargs...
 )
-    sampler = get!(counterfactual_explanation.params, :energy_sampler) do 
-        CCE.EnergySampler(counterfactual_explanation; kwargs...)
+    conditional_samples = []
+    ignore_derivatives() do
+        _dict = counterfactual_explanation.params
+        if !(:energy_sampler ∈ collect(keys(_dict)))
+            _dict[:energy_sampler] = CCE.EnergySampler(counterfactual_explanation; kwargs...)
+        end
+        sampler = _dict[:energy_sampler]
+        push!(conditional_samples, rand(sampler, n; retrain=retrain))
+    end
+    x′ = CounterfactualExplanations.counterfactual(counterfactual_explanation)
+    loss = map(eachslice(x′, dims=3)) do x
+        x = Matrix(x)
+        Δ = map(eachcol(conditional_samples[1])) do xsample
+            norm(x - xsample)
+        end
+        return mean(Δ)
     end
-    conditional_samples = 
+    loss = agg(loss)
+
+    return loss
+
 end
\ No newline at end of file
diff --git a/src/sampling.jl b/src/sampling.jl
index 8a6625ee..2cf346ca 100644
--- a/src/sampling.jl
+++ b/src/sampling.jl
@@ -13,7 +13,7 @@ end
 
 function EnergySampler(
     ce::CounterfactualExplanation;
-    opt::JointEnergyModels.AbstractSamplingRule=SGLD(),
+    opt::JointEnergyModels.AbstractSamplingRule=ImproperSGLD(),
     niter::Int=100,
     nsamples::Int=1000
 )
-- 
GitLab