Newer
Older
using ConformalPrediction
using CounterfactualExplanations.Models
using Flux
# Atomic model types presumably accepted by `ConformalModel`: probabilistic MLJFlux models,
# optionally wrapped in a probabilistic MLJ ensemble of such models.
# `Union{A, …}` already matches every subtype of `A`, so the first member needs no `<:`
# (writing `<:A` there only wraps the constant in an equivalent `UnionAll`).
const CompatibleAtomicModel = Union{
    MLJFlux.MLJFluxProbabilistic,
    MLJEnsembles.ProbabilisticEnsembleModel{<:MLJFlux.MLJFluxProbabilistic},
}
Pat Alt
committed
ConformalModel <: Models.AbstractDifferentiableModel
Pat Alt
committed
struct ConformalModel <: Models.AbstractDifferentiableModel
if likelihood ∈ [:classification_binary, :classification_multi] || isnothing(likelihood)
"`likelihood` should either be `nothing` or in `[:classification_binary,:classification_multi]`",
"""
    _get_chains(fitresult)

Private helper that extracts the chain(s) stored in a fitted model's `fitresult`.

- If `fitresult` is an `MLJEnsembles.WrappedEnsemble`, the first component of every
  ensemble member's fit result is collected.
- Otherwise the first component of `fitresult` itself is used.

Always returns a vector (with a single element for a non-ensemble model).
"""
function _get_chains(fitresult)
    chains = []
    # Extraction is done inside `ignore_derivatives`; presumably this keeps the AD backend
    # from tracing the mutation of `chains` below.
    ignore_derivatives() do
        if fitresult isa MLJEnsembles.WrappedEnsemble
            _chains = map(res -> res[1], fitresult.ensemble)
        else
            _chains = [fitresult[1]]
        end
        # `append!` adds the whole collection without splatting it into varargs.
        append!(chains, _chains)
    end
    return chains
end
"""
    _outdim(fitresult)

Private helper returning the output dimension of a fitted model, taken as the length of
the second component of its fit result. For an `MLJEnsembles.WrappedEnsemble`, the fit
result of the first ensemble member is inspected.
"""
function _outdim(fitresult)
    is_ensemble = fitresult isa MLJEnsembles.WrappedEnsemble
    member = is_ensemble ? fitresult.ensemble[1] : fitresult
    return length(member[2])
end
Private helper function that extracts the sampler from a fitted model.
"""
if hasfield(typeof(_mod), :model)
if _mod.model isa MLJEnsembles.EitherEnsembleModel
_mod = _mod.model
end
if _mod.model isa JointEnergyClassifier
sampler = _mod.model.sampler
else
sampler = false
end
Private helper function that checks if a fitted model has a sampler.
"""
"""
ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)
Outer constructor for `ConformalModel`. If `fitresult` is not specified, the model is not fitted and `likelihood` is inferred from the model. If `fitresult` is specified, `likelihood` is inferred from the output dimension of the model. If `likelihood` is not specified, it defaults to `:classification_binary`.
"""
function ConformalModel(model, fitresult=nothing; likelihood::Union{Nothing,Symbol}=nothing)
# Default to binary classification, if not specified or inferred:
"""
    get_logits(f::Flux.Chain, x)

Helper that returns the output of chain `f` on input `x` before any final activation.
If the last layer of `f` is a plain `Function` (assumed to be an output activation such as
`softmax`), that layer is skipped and `x` is passed through the remaining layers only;
otherwise `f(x)` is returned as is.
"""
function get_logits(f::Flux.Chain, x)
    last_layer = f[end]
    if last_layer isa Function
        # Drop the trailing activation and apply the rest of the chain.
        return f[1:(end - 1)](x)
    end
    return f(x)
end
@doc raw"""
Models.logits(M::ConformalModel, X::AbstractArray)
To keep things consistent with the architecture of `CounterfactualExplanations.jl`, this method computes logits $\beta_i x_i$ (i.e. the linear predictions) for a Conformal Classifier. By default, `MLJ.jl` and `ConformalPrediction.jl` return probabilistic predictions. To get the underlying logits, we invert the softmax function.
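Let $\hat{p}_i$ denote the estimated softmax output for class $i$. Then in the multi-class case the following formula can be applied (the second term is the same for every class, so logits are only recovered up to an additive constant):
```math
\beta_i x_i = \log (\hat{p}_i) + \log \left( \sum_j \exp(\beta_j x_j) \right)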
```
For a short derivation, see here: https://math.stackexchange.com/questions/2786600/invert-the-softmax-function.
In the binary case logits are fed through the sigmoid function instead of softmax, so we need to further adjust as follows,
```math
\beta x = \beta_1 x_1 - \beta_0 x_0
```
which follows from the derivation here: https://stats.stackexchange.com/questions/233658/softmax-vs-sigmoid-function-in-logistic-classifier
"""
function Models.logits(M::ConformalModel, X::AbstractArray)
ŷ = MLUtils.stack(map(chain -> get_logits(chain,x),_get_chains(fitresult))) |>
y -> mean(y, dims=ndims(y)) |>
y -> MLUtils.unstack(y, dims=ndims(y))[1]
if ndims(ŷ) == 2
ŷ = [ŷ]
"""
Models.probs(M::ConformalModel, X::AbstractArray)
Returns the estimated probabilities for a Conformal Classifier.
"""
end
"""
train(M::ConformalModel, data::CounterfactualData; kwrgs...)
Trains a Conformal Classifier `M` on `data`.
"""
function Models.train(M::ConformalModel, data::CounterfactualData; kwrgs...)
X, y = data.X, data.output_encoder.labels
X = table(permutedims(X))
conf_model = M.model
mach = machine(conf_model, X, y)
fit!(mach; kwrgs...)
likelihood, _ = CounterfactualExplanations.guess_likelihood(data.output_encoder.y)
return ConformalModel(mach.model, mach.fitresult, likelihood)