From b7a78dfa2a1d155f19883cedaf2317c20bc69b23 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Thu, 16 Feb 2023 10:49:48 +0100 Subject: [PATCH] basic set size loss penalty now for working counterfactual search --- Manifest.toml | 4 ++-- Project.toml | 1 + notebooks/conformal.qmd | 24 +++++++++++++++++++++--- src/ConformalGenerator.jl | 17 +++++++++++++---- src/model.jl | 5 +++-- 5 files changed, 40 insertions(+), 11 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index fea68c34..ec1a4d53 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -2,7 +2,7 @@ julia_version = "1.8.5" manifest_format = "2.0" -project_hash = "3488ba7b86c20f2db302dbcbc95e3f6227acdebb" +project_hash = "b1b5b9a95ae3c44dd7fa2c6d1d451e9bca8b9297" [[deps.AbstractFFTs]] deps = ["ChainRulesCore", "LinearAlgebra"] @@ -260,7 +260,7 @@ version = "0.6.2" deps = ["CSV", "CUDA", "CategoricalArrays", "DataFrames", "Flux", "LaplaceRedux", "LazyArtifacts", "LinearAlgebra", "MLDatasets", "MLJBase", "MLJModels", "MLUtils", "MultivariateStats", "NearestNeighborModels", "Parameters", "PkgTemplates", "Plots", "ProgressMeter", "Random", "Serialization", "SliceMap", "Statistics", "StatsBase", "Tables", "UMAP"] path = "../CounterfactualExplanations.jl" uuid = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" -version = "0.1.5" +version = "0.1.6" [[deps.Crayons]] git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" diff --git a/Project.toml b/Project.toml index 2ff890fd..83402ab8 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.1.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" +ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2" ConformalPrediction = "98bfc277-1877-43dc-819b-a3e38c30242f" CounterfactualExplanations = "2f13d31b-18db-44c1-bc43-ebaf2cff0be0" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index 5dcef400..1cea7016 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -35,7 +35,6 @@ fit!(mach) ```{julia} M = CCE.ConformalModel(conf_model, mach.fitresult) -generator = CCE.ConformalGenerator() ``` ```{julia} @@ -51,5 +50,24 @@ plot(p1, p2, size=(800,320)) ``` ```{julia} -ce = generate_counterfactual(x, target, counterfactual_data, M, generator) -``` +#| output: true +#| echo: false +# Generators: +generators = Dict( + "Generic (γ=0.5)" => GenericGenerator(opt = opt, decision_threshold=0.5), + "Generic (γ=0.9)" => GenericGenerator(opt = opt, decision_threshold=0.9), + "Conformal (λ₂=1)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,1]), + "Conformal (λ₂=10)" => CCE.ConformalGenerator(opt=opt, λ=[0.1,10]), +) + + +counterfactuals = Dict([name => generate_counterfactual(x, target, counterfactual_data, M, gen;) for (name, gen) in generators]) +# Plots: +plts = [] +for (name,ce) ∈ counterfactuals + plt = plot(ce; title=name, colorbar=false, ticks = false, legend=false, zoom=0) + plts = vcat(plts..., plt) +end +_n = length(generators) +plot(plts..., size=(_n * 185,200), layout=(1,_n)) +``` \ No newline at end of file diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl index b86bccc6..133a3362 100644 --- a/src/ConformalGenerator.jl +++ b/src/ConformalGenerator.jl @@ -13,6 +13,8 @@ mutable struct ConformalGenerator <: AbstractGradientBasedGenerator decision_threshold::Union{Nothing,AbstractFloat} opt::Flux.Optimise.AbstractOptimiser # optimizer τ::AbstractFloat # tolerance for convergence + κ::Real + temp::Real end # API streamlining: @@ -48,7 +50,7 @@ function ConformalGenerator(; kwargs..., ) params = ConformalGeneratorParams(; kwargs...) - ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ) + ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ, params.κ, params.temp) end """ @@ -68,6 +70,7 @@ function set_size_penalty( fitresult = counterfactual_explanation.M.fitresult X = CounterfactualExplanations.decode_state(counterfactual_explanation) loss = SliceMap.slicemap(X, dims=(1,2)) do x + x = Matrix(x) ConformalPrediction.smooth_size_loss( conf_model, fitresult, x; κ = generator.κ, @@ -97,13 +100,19 @@ function Generators.h( CounterfactualExplanations.decode_state(counterfactual_explanation), ) - # Euclidean norm of gradient: - Ω = set_size_penalty(generator, counterfactual_explanation) - + # Set size penalty: + in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5) + if in_target_domain + Ω = set_size_penalty(generator, counterfactual_explanation) + else + Ω = 0 + end + if length(generator.λ) == 1 penalty = generator.λ * (dist_ .+ Ω) else penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω end + return penalty end diff --git a/src/model.jl b/src/model.jl index 654af645..eddd8b30 100644 --- a/src/model.jl +++ b/src/model.jl @@ -67,9 +67,10 @@ function Models.logits(M::ConformalModel, X::AbstractArray) # return probas # end p̂ = fitresult[1](x) - if size(p̂, 2) > 1 - p̂ = reduce(hcat, p̂) + if ndims(p̂) == 2 + p̂ = [p̂] end + p̂ = reduce(hcat, p̂) ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂)))) if M.likelihood == :classification_binary ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ)))) -- GitLab