diff --git a/notebooks/conformal.qmd b/notebooks/conformal.qmd
index 015e7e52c7431dbe684b21618f48a441025285c5..57b3f6d79a9c51f6134e9600f28c089d654be724 100644
--- a/notebooks/conformal.qmd
+++ b/notebooks/conformal.qmd
@@ -4,7 +4,10 @@ using CCE
 using ConformalPrediction
 using CounterfactualExplanations
 using CounterfactualExplanations.Data
+using Flux
+using MLJBase
 using MLJFlux
+using Plots
 ```
 
 ```{julia}
@@ -12,7 +15,11 @@ counterfactual_data = load_linearly_separable()
 ```
 
 ```{julia}
-clf = NeuralNetworkClassifier()
+builder = MLJFlux.@builder Chain(
+    Dense(n_in, 32, relu),
+    Dense(32, n_out)
+)
+clf = NeuralNetworkClassifier(builder=builder, epochs=100)
 ```
 
@@ -23,3 +30,8 @@ conf_model = conformal_model(clf; method=:simple_inductive)
 mach = machine(conf_model, X, y)
 fit!(mach)
 ```
+
+```{julia}
+M = CCE.ConformalModel(conf_model, mach.fitresult)
+```
+
diff --git a/src/CCE.jl b/src/CCE.jl
index 58f70b9cdc1699e50aa93f9a2a413a221ed405e5..e0fe38fa371838c1d18b50807ab0dd3004894e45 100644
--- a/src/CCE.jl
+++ b/src/CCE.jl
@@ -1,5 +1,7 @@
 module CCE
 
+import MLJModelInterface as MMI
+
 include("model.jl")
 include("ConformalGenerator.jl")
 
diff --git a/src/ConformalGenerator.jl b/src/ConformalGenerator.jl
index c772b69a64248905d8531fabf67383f26af41ed9..540d71a6b64d446bcbdbb219cdca1aaeaee89a80 100644
--- a/src/ConformalGenerator.jl
+++ b/src/ConformalGenerator.jl
@@ -48,6 +48,31 @@ function ConformalGenerator(;
     ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
 end
 
+# Loss:
+# """
+#     ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)
+
+# The default method to apply the generator loss function to the current counterfactual state for any generator.
+# """
+# function ℓ(
+#     generator::ConformalGenerator,
+#     counterfactual_explanation::AbstractCounterfactualExplanation,
+# )
+
+#     loss_fun =
+#         !isnothing(generator.loss) ? getfield(Losses, generator.loss) :
+#         CounterfactualExplanations.guess_loss(counterfactual_explanation)
+#     @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model."
+#     loss = loss_fun(
+#         getfield(Models, :logits)(
+#             counterfactual_explanation.M,
+#             CounterfactualExplanations.decode_state(counterfactual_explanation),
+#         ),
+#         counterfactual_explanation.target_encoded,
+#     )
+#     return loss
+# end
+
 """
     set_size_penalty(
         generator::ConformalGenerator,
@@ -61,21 +86,7 @@ function set_size_penalty(
     counterfactual_explanation::AbstractCounterfactualExplanation,
 )
 
-    x_ = CounterfactualExplanations.decode_state(counterfactual_explanation)
-    M = counterfactual_explanation.M
-    model = isa(M.model, Vector) ? M.model : [M.model]
-    y_ = counterfactual_explanation.target_encoded
-
-    if M.likelihood == :classification_binary
-        loss_type = :logitbinarycrossentropy
-    else
-        loss_type = :logitcrossentropy
-    end
-
-    loss(x, y) =
-        sum([getfield(Flux.Losses, loss_type)(nn(x), y) for nn in model]) / length(model)
-    return loss(x_, y_)
 end
 
 # Complexity:
 
diff --git a/src/model.jl b/src/model.jl
index dc7e8a4ac612433cba9946acafad36602234d3bd..2abc8fa4404c537c32c1ff496d8f7acae4397549 100644
--- a/src/model.jl
+++ b/src/model.jl
@@ -1,6 +1,7 @@
 using ConformalPrediction
 using CounterfactualExplanations.Models
 using Flux
+using MLJBase
 using MLUtils
 using SliceMap
 using Statistics
@@ -12,10 +13,11 @@ Constructor for models trained in `Flux.jl`.
""" struct ConformalModel <: Models.AbstractDifferentiableJuliaModel model::ConformalPrediction.ConformalProbabilisticSet + fitresult::Any likelihood::Symbol - function ConformalModel(model, likelihood) + function ConformalModel(model, fitresult, likelihood) if likelihood ∈ [:classification_binary, :classification_multi] - new(model, likelihood) + new(model, fitresult, likelihood) else throw( ArgumentError( @@ -27,20 +29,57 @@ struct ConformalModel <: Models.AbstractDifferentiableJuliaModel end # Outer constructor method: -function ConformalModel(model; likelihood::Symbol=:classification_binary) - ConformalModel(model, likelihood) +function ConformalModel(model, fitresult; likelihood::Symbol=:classification_binary) + ConformalModel(model, fitresult, likelihood) end # Methods -function logits(M::ConformalModel, X::AbstractArray) - return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2)) +@doc raw""" + Models.logits(M::ConformalModel, X::AbstractArray) + +To keep things consistent with the architecture of `CounterfactualExplanations.jl`, this method computes logits $\beta_i x_i$ (i.e. the linear predictions) for a Conformal Classifier. By default, `MLJ.jl` and `ConformalPrediction.jl` return probabilistic predictions. To get the underlying logits, we invert the softmax function. + +Let $\hat{p}_i$ denote the estimated softmax output for feature $i$. Then in the multi-class case the following formula can be applied: + +```math +\beta_i x_i = \log (\hat{p}_i) + \log (\sum_i \exp(\hat{p}_i)) +``` + +For a short derivation, see here: https://math.stackexchange.com/questions/2786600/invert-the-softmax-function. + +In the binary case logits are fed through the sigmoid function instead of softmax, so we need to further adjust as follows, + +```math +\beta x = \beta_1 x_1 - \beta_0 x_0 +``` + +which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier +""" +function Models.logits(M::ConformalModel, X::AbstractArray) + yhat = SliceMap.slicemap(X, dims=(1, 2)) do x + conf_model = M.model + fitresult = M.fitresult + X = MLJBase.table(permutedims(X)) + p̂ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, X)...) + p̂ = map(p̂) do pp + L = p̂.decoder.classes + probas = pdf.(pp, L) + return probas + end + p̂ = reduce(hcat, p̂) + ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(p̂)))) + if M.likelihood == :classification_binary + p̂ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ)))) + end + end + return yhat end -function probs(M::ConformalModel, X::AbstractArray) +function Models.probs(M::ConformalModel, X::AbstractArray) if M.likelihood == :classification_binary - output = σ.(logits(M, X)) + output = σ.(Models.logits(M, X)) elseif M.likelihood == :classification_multi - output = softmax(logits(M, X)) + output = softmax(Models.logits(M, X)) end return output end \ No newline at end of file