From adf59a74dad4bbe479b7a0979b522897e3e12be7 Mon Sep 17 00:00:00 2001 From: pat-alt <altmeyerpat@gmail.com> Date: Wed, 15 Feb 2023 18:57:19 +0100 Subject: [PATCH] slowly slowly --- notebooks/conformal.qmd | 13 ++++++++++ src/ConformalGenerator.jl | 52 ++++++++++++++------------------------- src/model.jl | 7 +++--- 3 files changed, 35 insertions(+), 37 deletions(-) diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd index 57b3f6d7..9c45239d 100644 --- a/notebooks/conformal.qmd +++ b/notebooks/conformal.qmd @@ -31,7 +31,20 @@ mach = machine(conf_model, X, y) fit!(mach) ``` +```{julia} +contourf(mach.model, mach.fitresult, X, y; plot_set_size=true) +``` + +## Counterfactual Explanation + + ```{julia} M = CCE.ConformalModel(conf_model, mach.fitresult) +generator = CCE.ConformalGenerator() ``` +```{julia} +x = select_factual(counterfactual_data,rand(1:size(counterfactual_data.X,2))) +y = predict_label(M, counterfactual_data, x)[1] +target = counterfactual_data.y_levels[counterfactual_data.y_levels .!= y][1] +``` \ No newline at end of file diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl index bde032bd..b86bccc6 100644 --- a/src/ConformalGenerator.jl +++ b/src/ConformalGenerator.jl @@ -3,6 +3,7 @@ using CounterfactualExplanations.Generators using Flux using LinearAlgebra using Parameters +using SliceMap using Statistics mutable struct ConformalGenerator <: AbstractGradientBasedGenerator @@ -19,7 +20,7 @@ end opt::Flux.Optimise.AbstractOptimiser = Descent() τ::AbstractFloat = 1e-3 κ::Real = 1.0 - Τ::Real = 0.5 + temp::Real = 0.5 end """ @@ -50,31 +51,6 @@ function ConformalGenerator(; ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ) end -# Loss: -# """ -# ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation) - -# The default method to apply the generator loss function to the current counterfactual state for any generator. -# """ -# function ℓ( -# generator::ConformalGenerator, -# counterfactual_explanation::AbstractCounterfactualExplanation, -# ) - -# loss_fun = -# !isnothing(generator.loss) ? getfield(Losses, generator.loss) : -# CounterfactualExplanations.guess_loss(counterfactual_explanation) -# @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model." -# loss = loss_fun( -# getfield(Models, :logits)( -# counterfactual_explanation.M, -# CounterfactualExplanations.decode_state(counterfactual_explanation), -# ), -# counterfactual_explanation.target_encoded, -# ) -# return loss -# end - """ set_size_penalty( generator::ConformalGenerator, @@ -88,6 +64,19 @@ function set_size_penalty( counterfactual_explanation::AbstractCounterfactualExplanation, ) + conf_model = counterfactual_explanation.M.model + fitresult = counterfactual_explanation.M.fitresult + X = CounterfactualExplanations.decode_state(counterfactual_explanation) + loss = SliceMap.slicemap(X, dims=(1,2)) do x + ConformalPrediction.smooth_size_loss( + conf_model, fitresult, x; + κ = generator.κ, + temp = generator.temp + ) + end + loss = mean(loss) + + return loss end @@ -109,17 +98,12 @@ function Generators.h( ) # Euclidean norm of gradient: - in_target_domain = all(target_probs(counterfactual_explanation) .>= 0.5) - if in_target_domain - grad_norm = gradient_penalty(generator, counterfactual_explanation) - else - grad_norm = 0 - end + Ω = set_size_penalty(generator, counterfactual_explanation) if length(generator.λ) == 1 - penalty = generator.λ * (dist_ .+ grad_norm) + penalty = generator.λ * (dist_ .+ Ω) else - penalty = generator.λ[1] * dist_ .+ generator.λ[2] * grad_norm + penalty = generator.λ[1] * dist_ .+ generator.λ[2] * Ω end return penalty end diff --git a/src/model.jl b/src/model.jl index 2abc8fa4..1db725d1 100644 --- a/src/model.jl +++ b/src/model.jl @@ -59,8 +59,8 @@ function Models.logits(M::ConformalModel, X::AbstractArray) yhat = SliceMap.slicemap(X, dims=(1, 2)) do x conf_model = M.model fitresult = M.fitresult - X = MLJBase.table(permutedims(X)) - p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, X)...) + x = MLJBase.table(permutedims(x)) + p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, x)...) p̂ = map(p̂) do pp L = p̂.decoder.classes probas = pdf.(pp, L) @@ -69,8 +69,9 @@ function Models.logits(M::ConformalModel, X::AbstractArray) p̂ = reduce(hcat, p̂) ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂)))) if M.likelihood == :classification_binary - p̂ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ)))) + ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ)))) end + return ŷ end return yhat end -- GitLab