Commit f2512bf3 authored by pat-alt

penalties

parent 2f60dc80
@@ -4,7 +4,10 @@ using CCE
using ConformalPrediction
using CounterfactualExplanations
using CounterfactualExplanations.Data
using Flux
using MLJBase
using MLJFlux
using Plots
```
```{julia}
@@ -12,7 +15,11 @@ counterfactual_data = load_linearly_separable()
```
```{julia}
clf = NeuralNetworkClassifier()
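# Note: `n_in` and `n_out` are filled in automatically by `MLJFlux.@builder` based on the data: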
builder = MLJFlux.@builder Chain(
    Dense(n_in, 32, relu),
    Dense(32, n_out)
)
clf = NeuralNetworkClassifier(builder=builder, epochs=100)
```
@@ -23,3 +30,8 @@ conf_model = conformal_model(clf; method=:simple_inductive)
mach = machine(conf_model, X, y)
fit!(mach)
```
```{julia}
M = CCE.ConformalModel(conf_model, mach.fitresult)
```
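
The wrapped `ConformalModel` can now be plugged into the usual `CounterfactualExplanations.jl` workflow. A minimal sketch of what that looks like (the `GenericGenerator` and the target label `2` are illustrative choices, not part of this commit; the names used are exported by `CounterfactualExplanations`, which is already loaded above):

```{julia}
# Sketch: pick a random factual instance and search for a counterfactual.
x = select_factual(counterfactual_data, rand(1:size(counterfactual_data.X, 2)))
target = 2  # illustrative target class for this two-class dataset
ce = generate_counterfactual(x, target, counterfactual_data, M, GenericGenerator())
```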
module CCE
import MLJModelInterface as MMI
include("model.jl")
include("ConformalGenerator.jl")
@@ -48,6 +48,31 @@ function ConformalGenerator(;
    ConformalGenerator(loss, complexity, λ, decision_threshold, params.opt, params.τ)
end
# Loss:
# """
# ℓ(generator::ConformalGenerator, counterfactual_explanation::AbstractCounterfactualExplanation)
# The default method to apply the generator loss function to the current counterfactual state for any generator.
# """
# function ℓ(
# generator::ConformalGenerator,
# counterfactual_explanation::AbstractCounterfactualExplanation,
# )
# loss_fun =
# !isnothing(generator.loss) ? getfield(Losses, generator.loss) :
# CounterfactualExplanations.guess_loss(counterfactual_explanation)
# @assert !isnothing(loss_fun) "No loss function provided and loss function could not be guessed based on model."
# loss = loss_fun(
# getfield(Models, :logits)(
# counterfactual_explanation.M,
# CounterfactualExplanations.decode_state(counterfactual_explanation),
# ),
# counterfactual_explanation.target_encoded,
# )
# return loss
# end
"""
set_size_penalty(
generator::ConformalGenerator,
......@@ -61,21 +86,7 @@ function set_size_penalty(
counterfactual_explanation::AbstractCounterfactualExplanation,
)
    x_ = CounterfactualExplanations.decode_state(counterfactual_explanation)
    M = counterfactual_explanation.M
    model = isa(M.model, Vector) ? M.model : [M.model]
    y_ = counterfactual_explanation.target_encoded
    if M.likelihood == :classification_binary
        loss_type = :logitbinarycrossentropy
    else
        loss_type = :logitcrossentropy
    end
    loss(x, y) =
        sum([getfield(Flux.Losses, loss_type)(nn(x), y) for nn in model]) / length(model)
    return loss(x_, y_)
end
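
# Illustration only, not part of the package: the quantity a set-size penalty
# targets is the size of the conformal prediction set. Assuming split-conformal
# nonconformity scores s(x, y) = 1 - p̂_y and a calibrated threshold q̂, a
# differentiable stand-in could look like this; `temp` controls how sharply
# the soft count approximates the hard set size.
function smooth_set_size(p̂::AbstractVector, q̂::Real; temp::Real=0.1)
    scores = 1 .- p̂                              # nonconformity score per class
    return sum(Flux.σ.((q̂ .- scores) ./ temp))   # soft count of classes in the set
end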
# Complexity:
using ConformalPrediction
using CounterfactualExplanations.Models
using Flux
using MLJBase
using MLUtils
using SliceMap
using Statistics
@@ -12,10 +13,11 @@ Constructor for models trained in `Flux.jl`.
"""
struct ConformalModel <: Models.AbstractDifferentiableJuliaModel
    model::ConformalPrediction.ConformalProbabilisticSet
    fitresult::Any
    likelihood::Symbol
    function ConformalModel(model, likelihood)
    function ConformalModel(model, fitresult, likelihood)
        if likelihood ∈ [:classification_binary, :classification_multi]
            new(model, likelihood)
            new(model, fitresult, likelihood)
        else
            throw(
                ArgumentError(
@@ -27,20 +29,57 @@ struct ConformalModel <: Models.AbstractDifferentiableJuliaModel
end

# Outer constructor method:
function ConformalModel(model; likelihood::Symbol=:classification_binary)
    ConformalModel(model, likelihood)
function ConformalModel(model, fitresult; likelihood::Symbol=:classification_binary)
    ConformalModel(model, fitresult, likelihood)
end
# Methods
function logits(M::ConformalModel, X::AbstractArray)
    return SliceMap.slicemap(x -> M.model(x), X, dims=(1, 2))
@doc raw"""
Models.logits(M::ConformalModel, X::AbstractArray)
To keep things consistent with the architecture of `CounterfactualExplanations.jl`, this method computes logits $\beta_i x_i$ (i.e. the linear predictions) for a Conformal Classifier. By default, `MLJ.jl` and `ConformalPrediction.jl` return probabilistic predictions. To get the underlying logits, we invert the softmax function.
Let $\hat{p}_i$ denote the estimated softmax output for feature $i$. Then in the multi-class case the following formula can be applied:
```math
\beta_i x_i = \log (\hat{p}_i) + \log (\sum_i \exp(\hat{p}_i))
```
For a short derivation, see here: https://math.stackexchange.com/questions/2786600/invert-the-softmax-function.
In the binary case logits are fed through the sigmoid function instead of softmax, so we need to further adjust as follows,
```math
\beta x = \beta_1 x_1 - \beta_0 x_0
```
which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
"""
function Models.logits(M::ConformalModel, X::AbstractArray)
    yhat = SliceMap.slicemap(X, dims=(1, 2)) do x
        conf_model = M.model
        fitresult = M.fitresult
        # MLJ expects tabular input with observations as rows:
        Xtab = MLJBase.table(permutedims(x))
        ŷ = MMI.predict(conf_model.model, fitresult, MMI.reformat(conf_model.model, Xtab)...)
        ŷ = map(ŷ) do pp
            L = ŷ.decoder.classes
            probas = pdf.(pp, L)
            return probas
        end
        ŷ = reduce(hcat, ŷ)
        # Invert the softmax to recover logits (up to an additive constant):
        ŷ = reduce(hcat, (map(p -> log.(p) .+ log(sum(exp.(p))), eachcol(ŷ))))
        if M.likelihood == :classification_binary
            # Collapse the two class logits into a single binary logit:
            ŷ = reduce(hcat, (map(y -> y[2] - y[1], eachcol(ŷ))))
        end
        return ŷ
    end
    return yhat
end
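
# Quick check of the inversion (illustrative): because the softmax is
# shift-invariant, mapping the recovered logits back reproduces the inputs:
#   p = softmax([1.0, 2.0, 0.5])
#   z = log.(p) .+ log(sum(exp.(p)))
#   softmax(z) ≈ p  # true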
function probs(M::ConformalModel, X::AbstractArray)
function Models.probs(M::ConformalModel, X::AbstractArray)
    if M.likelihood == :classification_binary
        output = σ.(logits(M, X))
        output = σ.(Models.logits(M, X))
    elseif M.likelihood == :classification_multi
        output = softmax(logits(M, X))
        output = softmax(Models.logits(M, X))
    end
    return output
end
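
# Numerical check (illustrative): the binary collapse in `logits` composes
# correctly with the sigmoid applied in `probs`:
#   p = [0.3, 0.7]                    # binary softmax output
#   z = log.(p) .+ log(sum(exp.(p)))  # recovered logits
#   σ(z[2] - z[1]) ≈ p[2]             # true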