From b7a78dfa2a1d155f19883cedaf2317c20bc69b23 Mon Sep 17 00:00:00 2001
From: pat-alt <altmeyerpat@gmail.com>
Date: Thu, 16 Feb 2023 10:49:48 +0100
Subject: [PATCH] basic set size loss penalty now for working counterfactual
 search

---
 Manifest.toml             |  4 ++--
 Project.toml              |  1 +
 notebooks/conformal.qmd   | 24 +++++++++++++++++++++---
 src/ConformalGenerator.jl | 17 +++++++++++++----
 src/model.jl              |  5 +++--
 5 files changed, 40 insertions(+), 11 deletions(-)

diff --git a/Manifest.toml b/Manifest.toml
index fea68c34..ec1a4d53 100644
--- a/Manifest.toml
+++ b/Manifest.toml
@@ -2,7 +2,7 @@
 
 julia_version = "1.8.5"
 manifest_format = "2.0"
-project_hash = "3488ba7b86c20f2db302dbcbc95e3f6227acdebb"
+project_hash = "b1b5b9a95ae3c44dd7fa2c6d1d451e9bca8b9297"
 
 [[deps.AbstractFFTs]]
 deps = ["ChainRulesCore", "LinearAlgebra"]
@@ -260,7 +260,7 @@ version = "0.6.2"
 deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "PkgTemplates", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "Statistics", "StatsBase", "Tables", "UMAP"]
 path = "../CounterfactualExplanations.jl"
 uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
-version = "0.1.5"
+version = "0.1.6"
 
 [[deps.Crayons]]
 git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15"
diff --git a/Project.toml b/Project.toml
index 2ff890fd..83402ab8 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ version = "0.1.0"
 
 [deps]
 CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
+ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
 ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f"
 CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd
index 5dcef400..1cea7016 100644
--- a/notebooks/conformal.qmd
+++ b/notebooks/conformal.qmd
@@ -35,7 +35,6 @@ fit!(mach)
 
 ```{julia}
 M = CCE.ConformalModel(conf_model, mach.fitresult)
-generator = CCE.ConformalGenerator()
 ```
 
 ```{julia}
@@ -51,5 +50,24 @@ plot(p1, p2, size=(800,320))
 ```
 
 ```{julia}
-ce = generate_counterfactual(x, target, counterfactual_data, M, generator)
-```
+#| output: true
+#| echo: false
+# Generators:
+generators = Dict(
+    "Generic (γ=0.5)" => GenericGenerator(opt = opt, decision_threshold=0.5),
+    "Generic (γ=0.9)" => GenericGenerator(opt = opt, decision_threshold=0.9),
+    "Conformal (λ₂=1)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]),
+    "Conformal (λ₂=10)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]),
+)
+
+
+counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen;) for (name, gen) in generators])
+# Plots:
+plts = []
+for (name,ce) ∈ counterfactuals
+    plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0)
+    plts = vcat(plts..., plt)
+end
+_n = length(generators)
+plot(plts..., size=(_n * 185,200), layout=(1,_n))
+```
\ No newline at end of file
diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl
index b86bccc6..133a3362 100644
--- a/src/ConformalGenerator.jl
+++ b/src/ConformalGenerator.jl
@@ -13,6 +13,8 @@ mutable struct ConformalGenerator <: AbstractGradientBasedGenerator
     decision_threshold::Union{Nothing,AbstractFloat}
     opt::Flux.Optimise.AbstractOptimiser # optimizer
     Ï„::AbstractFloat # tolerance for convergence
+    κ::Real
+    temp::Real
 end
 
 # API streamlining:
@@ -48,7 +50,7 @@ function ConformalGenerator(;
     kwargs...,
 )
     params = ConformalGeneratorParams(; kwargs...)
-    ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
+    ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp)
 end
 
 """
@@ -68,6 +70,7 @@ function set_size_penalty(
     fitresult = counterfactual_explanation.M.fitresult
     X = CounterfactualExplanations.decode_state(counterfactual_explanation)
     loss = SliceMap.slicemap(X, dims=(1,2)) do x
+        x = Matrix(x)
         ConformalPrediction.smooth_size_loss(
             conf_model, fitresult, x;
             κ = generator.κ,
@@ -97,13 +100,19 @@ function Generators.h(
         CounterfactualExplanations.decode_state(counterfactual_explanation),
     )
 
-    # Euclidean norm of gradient:
-    Ω = set_size_penalty(generator, counterfactual_explanation)
-
+    # Set size penalty:
+    in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5)
+    if in_target_domain
+        Ω = set_size_penalty(generator, counterfactual_explanation)
+    else
+        Ω = 0
+    end
+    
     if length(generator.λ) == 1
         penalty = generator.λ * (dist_ .+ Ω)
     else
         penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω
     end
+
     return penalty
 end
diff --git a/src/model.jl b/src/model.jl
index 654af645..eddd8b30 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -67,9 +67,10 @@ function Models.logits(M::ConformalModel, X::AbstractArray)
         #     return probas
         # end
         p̂ = fitresult[1](x)
-        if size(p̂, 2) > 1
-            p̂ = reduce(hcat, p̂)
+        if ndims(p̂) == 2
+            p̂ = [p̂]
         end
+        p̂ = reduce(hcat, p̂)
         ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂))))
         if M.likelihood == :classification_binary
             ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))
-- 
GitLab